Merge branch 'main' into nvidia-e2e-notebook

This commit is contained in:
Jash Gulabrai 2025-06-06 11:11:53 -04:00
commit 1a492ad0cc
200 changed files with 8714 additions and 3175 deletions

View file

@ -6,12 +6,12 @@
from typing import Any
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.datatypes import AccessRule, Api
from .config import MetaReferenceAgentsImplConfig
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]):
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
from .agents import MetaReferenceAgentsImpl
impl = MetaReferenceAgentsImpl(
@ -21,6 +21,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
deps[Api.safety],
deps[Api.tool_runtime],
deps[Api.tool_groups],
policy,
)
await impl.initialize()
return impl

View file

@ -60,6 +60,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import AccessRule
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
BuiltinTool,
@ -96,13 +97,14 @@ class ChatAgent(ShieldRunnerMixin):
vector_io_api: VectorIO,
persistence_store: KVStore,
created_at: str,
policy: list[AccessRule],
):
self.agent_id = agent_id
self.agent_config = agent_config
self.inference_api = inference_api
self.safety_api = safety_api
self.vector_io_api = vector_io_api
self.storage = AgentPersistence(agent_id, persistence_store)
self.storage = AgentPersistence(agent_id, persistence_store, policy)
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
self.created_at = created_at

View file

@ -29,6 +29,7 @@ from llama_stack.apis.agents import (
Session,
Turn,
)
from llama_stack.apis.agents.openai_responses import OpenAIResponseText
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import (
Inference,
@ -40,6 +41,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import AccessRule
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
@ -61,6 +63,7 @@ class MetaReferenceAgentsImpl(Agents):
safety_api: Safety,
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
policy: list[AccessRule],
):
self.config = config
self.inference_api = inference_api
@ -71,6 +74,7 @@ class MetaReferenceAgentsImpl(Agents):
self.in_memory_store = InmemoryKVStoreImpl()
self.openai_responses_impl: OpenAIResponsesImpl | None = None
self.policy = policy
async def initialize(self) -> None:
self.persistence_store = await kvstore_impl(self.config.persistence_store)
@ -129,6 +133,7 @@ class MetaReferenceAgentsImpl(Agents):
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
),
created_at=agent_info.created_at,
policy=self.policy,
)
async def create_agent_session(
@ -324,10 +329,12 @@ class MetaReferenceAgentsImpl(Agents):
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
) -> OpenAIResponseObject:
return await self.openai_responses_impl.create_openai_response(
input, model, instructions, previous_response_id, store, stream, temperature, tools
input, model, instructions, previous_response_id, store, stream, temperature, text, tools, max_infer_iters
)
async def list_openai_responses(

View file

@ -8,7 +8,7 @@ import json
import time
import uuid
from collections.abc import AsyncIterator
from typing import Any, cast
from typing import Any
from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
@ -37,6 +37,8 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseOutputMessageWebSearchToolCall,
OpenAIResponseText,
OpenAIResponseTextFormat,
)
from llama_stack.apis.inference.inference import (
Inference,
@ -50,7 +52,12 @@ from llama_stack.apis.inference.inference import (
OpenAIChoice,
OpenAIDeveloperMessageParam,
OpenAIImageURL,
OpenAIJSONSchema,
OpenAIMessageParam,
OpenAIResponseFormatJSONObject,
OpenAIResponseFormatJSONSchema,
OpenAIResponseFormatParam,
OpenAIResponseFormatText,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
@ -158,6 +165,21 @@ async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> Open
)
async def _convert_response_text_to_chat_response_format(text: OpenAIResponseText) -> OpenAIResponseFormatParam:
"""
Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
"""
if not text.format or text.format["type"] == "text":
return OpenAIResponseFormatText(type="text")
if text.format["type"] == "json_object":
return OpenAIResponseFormatJSONObject()
if text.format["type"] == "json_schema":
return OpenAIResponseFormatJSONSchema(
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
)
raise ValueError(f"Unsupported text format: {text.format}")
async def _get_message_type_by_role(role: str):
role_to_type = {
"user": OpenAIUserMessageParam,
@ -178,8 +200,8 @@ class ChatCompletionContext(BaseModel):
messages: list[OpenAIMessageParam]
tools: list[ChatCompletionToolParam] | None = None
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
stream: bool
temperature: float | None
response_format: OpenAIResponseFormatParam
class OpenAIResponsesImpl:
@ -258,37 +280,6 @@ class OpenAIResponsesImpl:
"""
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
async def _process_response_choices(
self,
chat_response: OpenAIChatCompletion,
ctx: ChatCompletionContext,
tools: list[OpenAIResponseInputTool] | None,
) -> list[OpenAIResponseOutput]:
"""Handle tool execution and response message creation."""
output_messages: list[OpenAIResponseOutput] = []
# Execute tool calls if any
for choice in chat_response.choices:
if choice.message.tool_calls and tools:
# Assume if the first tool is a function, all tools are functions
if tools[0].type == "function":
for tool_call in choice.message.tool_calls:
output_messages.append(
OpenAIResponseOutputMessageFunctionToolCall(
arguments=tool_call.function.arguments or "",
call_id=tool_call.id,
name=tool_call.function.name or "",
id=f"fc_{uuid.uuid4()}",
status="completed",
)
)
else:
tool_messages = await self._execute_tool_and_return_final_output(choice, ctx)
output_messages.extend(tool_messages)
else:
output_messages.append(await _convert_chat_choice_to_response_message(choice))
return output_messages
async def _store_response(
self,
response: OpenAIResponseObject,
@ -331,10 +322,52 @@ class OpenAIResponsesImpl:
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
):
stream = False if stream is None else stream
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
stream_gen = self._create_streaming_response(
input=input,
model=model,
instructions=instructions,
previous_response_id=previous_response_id,
store=store,
temperature=temperature,
text=text,
tools=tools,
max_infer_iters=max_infer_iters,
)
if stream:
return stream_gen
else:
response = None
async for stream_chunk in stream_gen:
if stream_chunk.type == "response.completed":
if response is not None:
raise ValueError("The response stream completed multiple times! Earlier response: {response}")
response = stream_chunk.response
# don't leave the generator half complete!
if response is None:
raise ValueError("The response stream never completed")
return response
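Not part of the diff: a runnable stand-in for the non-streaming path above, where the wrapper fully drains the stream generator and returns the payload carried by the single "response.completed" event. The Event class and fake stream are simplified stand-ins, not llama_stack types.
import asyncio
from dataclasses import dataclass

@dataclass
class Event:
    type: str
    response: str | None = None

async def fake_stream():
    yield Event("response.created")
    yield Event("response.output_text.delta")
    yield Event("response.completed", response="final payload")

async def drain(stream):
    response = None
    async for chunk in stream:
        if chunk.type == "response.completed":
            if response is not None:
                raise ValueError("The response stream completed multiple times!")
            response = chunk.response
    # don't leave the generator half complete, and fail loudly if it never finished
    if response is None:
        raise ValueError("The response stream never completed")
    return response

print(asyncio.run(drain(fake_stream())))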
async def _create_streaming_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
store: bool | None = True,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
) -> AsyncIterator[OpenAIResponseObjectStream]:
output_messages: list[OpenAIResponseOutput] = []
# Input preprocessing
@ -342,7 +375,10 @@ class OpenAIResponsesImpl:
messages = await _convert_response_input_to_chat_messages(input)
await self._prepend_instructions(messages, instructions)
# Tool setup
# Structured outputs
response_format = await _convert_response_text_to_chat_response_format(text)
# Tool setup, TODO: refactor this slightly since this can also yield events
chat_tools, mcp_tool_to_server, mcp_list_message = (
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
)
@ -354,89 +390,10 @@ class OpenAIResponsesImpl:
messages=messages,
tools=chat_tools,
mcp_tool_to_server=mcp_tool_to_server,
stream=stream,
temperature=temperature,
response_format=response_format,
)
inference_result = await self.inference_api.openai_chat_completion(
model=model,
messages=messages,
tools=chat_tools,
stream=stream,
temperature=temperature,
)
if stream:
return self._create_streaming_response(
inference_result=inference_result,
ctx=ctx,
output_messages=output_messages,
input=input,
model=model,
store=store,
tools=tools,
)
else:
return await self._create_non_streaming_response(
inference_result=inference_result,
ctx=ctx,
output_messages=output_messages,
input=input,
model=model,
store=store,
tools=tools,
)
async def _create_non_streaming_response(
self,
inference_result: Any,
ctx: ChatCompletionContext,
output_messages: list[OpenAIResponseOutput],
input: str | list[OpenAIResponseInput],
model: str,
store: bool | None,
tools: list[OpenAIResponseInputTool] | None,
) -> OpenAIResponseObject:
chat_response = OpenAIChatCompletion(**inference_result.model_dump())
# Process response choices (tool execution and message creation)
output_messages.extend(
await self._process_response_choices(
chat_response=chat_response,
ctx=ctx,
tools=tools,
)
)
response = OpenAIResponseObject(
created_at=chat_response.created,
id=f"resp-{uuid.uuid4()}",
model=model,
object="response",
status="completed",
output=output_messages,
)
logger.debug(f"OpenAI Responses response: {response}")
# Store response if requested
if store:
await self._store_response(
response=response,
input=input,
)
return response
async def _create_streaming_response(
self,
inference_result: Any,
ctx: ChatCompletionContext,
output_messages: list[OpenAIResponseOutput],
input: str | list[OpenAIResponseInput],
model: str,
store: bool | None,
tools: list[OpenAIResponseInputTool] | None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Create initial response and emit response.created immediately
response_id = f"resp-{uuid.uuid4()}"
created_at = int(time.time())
@ -448,87 +405,144 @@ class OpenAIResponsesImpl:
object="response",
status="in_progress",
output=output_messages.copy(),
text=text,
)
# Emit response.created immediately
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
# For streaming, inference_result is an async iterator of chunks
# Stream chunks and emit delta events as they arrive
chat_response_id = ""
chat_response_content = []
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
chunk_created = 0
chunk_model = ""
chunk_finish_reason = ""
sequence_number = 0
n_iter = 0
messages = ctx.messages.copy()
# Create a placeholder message item for delta events
message_item_id = f"msg_{uuid.uuid4()}"
async for chunk in inference_result:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
for chunk_choice in chunk.choices:
# Emit incremental text content as delta events
if chunk_choice.delta.content:
sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
content_index=0,
delta=chunk_choice.delta.content,
item_id=message_item_id,
output_index=0,
sequence_number=sequence_number,
)
# Collect content for final response
chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
# Aggregate tool call arguments across chunks, using their index as the aggregation key
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
if response_tool_call:
response_tool_call.function.arguments += tool_call.function.arguments
else:
tool_call_dict: dict[str, Any] = tool_call.model_dump()
tool_call_dict.pop("type", None)
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
chat_response_tool_calls[tool_call.index] = response_tool_call
# Convert collected chunks to complete response
if chat_response_tool_calls:
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
else:
tool_calls = None
assistant_message = OpenAIAssistantMessageParam(
content="".join(chat_response_content),
tool_calls=tool_calls,
)
chat_response_obj = OpenAIChatCompletion(
id=chat_response_id,
choices=[
OpenAIChoice(
message=assistant_message,
finish_reason=chunk_finish_reason,
index=0,
)
],
created=chunk_created,
model=chunk_model,
)
# Process response choices (tool execution and message creation)
output_messages.extend(
await self._process_response_choices(
chat_response=chat_response_obj,
ctx=ctx,
tools=tools,
while True:
completion_result = await self.inference_api.openai_chat_completion(
model=ctx.model,
messages=messages,
tools=ctx.tools,
stream=True,
temperature=ctx.temperature,
response_format=ctx.response_format,
)
)
# Process streaming chunks and build complete response
chat_response_id = ""
chat_response_content = []
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
chunk_created = 0
chunk_model = ""
chunk_finish_reason = ""
sequence_number = 0
# Create a placeholder message item for delta events
message_item_id = f"msg_{uuid.uuid4()}"
async for chunk in completion_result:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
for chunk_choice in chunk.choices:
# Emit incremental text content as delta events
if chunk_choice.delta.content:
sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
content_index=0,
delta=chunk_choice.delta.content,
item_id=message_item_id,
output_index=0,
sequence_number=sequence_number,
)
# Collect content for final response
chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
# Aggregate tool call arguments across chunks
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
if response_tool_call:
# Don't attempt to concatenate arguments if we don't have any new arguments
if tool_call.function.arguments:
# Guard against an initial None argument before we concatenate
response_tool_call.function.arguments = (
response_tool_call.function.arguments or ""
) + tool_call.function.arguments
else:
tool_call_dict: dict[str, Any] = tool_call.model_dump()
tool_call_dict.pop("type", None)
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
chat_response_tool_calls[tool_call.index] = response_tool_call
# Convert collected chunks to complete response
if chat_response_tool_calls:
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
else:
tool_calls = None
assistant_message = OpenAIAssistantMessageParam(
content="".join(chat_response_content),
tool_calls=tool_calls,
)
current_response = OpenAIChatCompletion(
id=chat_response_id,
choices=[
OpenAIChoice(
message=assistant_message,
finish_reason=chunk_finish_reason,
index=0,
)
],
created=chunk_created,
model=chunk_model,
)
function_tool_calls = []
non_function_tool_calls = []
next_turn_messages = messages.copy()
for choice in current_response.choices:
next_turn_messages.append(choice.message)
if choice.message.tool_calls and tools:
for tool_call in choice.message.tool_calls:
if _is_function_tool_call(tool_call, tools):
function_tool_calls.append(tool_call)
else:
non_function_tool_calls.append(tool_call)
else:
output_messages.append(await _convert_chat_choice_to_response_message(choice))
# execute non-function tool calls
for tool_call in non_function_tool_calls:
tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx)
if tool_call_log:
output_messages.append(tool_call_log)
if tool_response_message:
next_turn_messages.append(tool_response_message)
for tool_call in function_tool_calls:
output_messages.append(
OpenAIResponseOutputMessageFunctionToolCall(
arguments=tool_call.function.arguments or "",
call_id=tool_call.id,
name=tool_call.function.name or "",
id=f"fc_{uuid.uuid4()}",
status="completed",
)
)
if not function_tool_calls and not non_function_tool_calls:
break
if function_tool_calls:
logger.info("Exiting inference loop since there is a function (client-side) tool call")
break
n_iter += 1
if n_iter >= max_infer_iters:
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}")
break
messages = next_turn_messages
# Create final response
final_response = OpenAIResponseObject(
@ -537,18 +551,19 @@ class OpenAIResponsesImpl:
model=model,
object="response",
status="completed",
text=text,
output=output_messages,
)
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
if store:
await self._store_response(
response=final_response,
input=input,
)
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
async def _convert_response_tools_to_chat_tools(
self, tools: list[OpenAIResponseInputTool]
) -> tuple[
@ -641,49 +656,6 @@ class OpenAIResponsesImpl:
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
return chat_tools, mcp_tool_to_server, mcp_list_message
async def _execute_tool_and_return_final_output(
self,
choice: OpenAIChoice,
ctx: ChatCompletionContext,
) -> list[OpenAIResponseOutput]:
output_messages: list[OpenAIResponseOutput] = []
if not isinstance(choice.message, OpenAIAssistantMessageParam):
return output_messages
if not choice.message.tool_calls:
return output_messages
next_turn_messages = ctx.messages.copy()
# Add the assistant message with tool_calls response to the messages list
next_turn_messages.append(choice.message)
for tool_call in choice.message.tool_calls:
# TODO: telemetry spans for tool calls
tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx)
if tool_call_log:
output_messages.append(tool_call_log)
if further_input:
next_turn_messages.append(further_input)
tool_results_chat_response = await self.inference_api.openai_chat_completion(
model=ctx.model,
messages=next_turn_messages,
stream=ctx.stream,
temperature=ctx.temperature,
)
# type cast to appease mypy: this is needed because we don't handle streaming properly :)
tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
# Huge TODO: these are NOT the final outputs, we must keep the loop going
tool_final_outputs = [
await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices
]
# TODO: Wire in annotations with URLs, titles, etc to these output messages
output_messages.extend(tool_final_outputs)
return output_messages
async def _execute_tool_call(
self,
tool_call: OpenAIChatCompletionToolCall,
@ -767,5 +739,20 @@ class OpenAIResponsesImpl:
else:
raise ValueError(f"Unknown result content type: {type(result.content)}")
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
else:
text = str(error_exc)
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
return message, input_message
def _is_function_tool_call(
tool_call: OpenAIChatCompletionToolCall,
tools: list[OpenAIResponseInputTool],
) -> bool:
if not tool_call.function:
return False
for t in tools:
if t.type == "function" and t.name == tool_call.function.name:
return True
return False
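Not part of the diff: a compact, self-contained sketch of the new inference loop, which keeps calling the model, executes server-side (non-function) tool calls inline, and stops when there are no tool calls, when a client-side function call appears, or when max_infer_iters is reached. call_model and execute_tool below are placeholders, not llama_stack APIs.
import asyncio

async def call_model(messages):           # placeholder for inference_api.openai_chat_completion
    if len(messages) == 1:
        return {"tool_calls": [{"name": "web_search", "args": "{}"}], "content": None}
    return {"tool_calls": [], "content": "final answer"}

async def execute_tool(tool_call):        # placeholder for _execute_tool_call
    return {"role": "tool", "content": f"result of {tool_call['name']}"}

FUNCTION_TOOLS = set()                    # names of tools the client executes itself

async def run(messages, max_infer_iters=10):
    output = []
    n_iter = 0
    while True:
        choice = await call_model(messages)
        next_turn = messages + [{"role": "assistant", **choice}]
        function_calls = [t for t in choice["tool_calls"] if t["name"] in FUNCTION_TOOLS]
        server_calls = [t for t in choice["tool_calls"] if t["name"] not in FUNCTION_TOOLS]
        if not choice["tool_calls"]:
            output.append(choice["content"])
        for t in server_calls:
            next_turn.append(await execute_tool(t))
        if not function_calls and not server_calls:
            break
        if function_calls:
            break  # hand control back to the client for function (client-side) tools
        n_iter += 1
        if n_iter >= max_infer_iters:
            break
        messages = next_turn
    return output

print(asyncio.run(run([{"role": "user", "content": "search something"}])))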

View file

@ -10,9 +10,10 @@ import uuid
from datetime import datetime, timezone
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.access_control.datatypes import AccessRule
from llama_stack.distribution.datatypes import User
from llama_stack.distribution.request_headers import get_authenticated_user
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
@ -22,7 +23,9 @@ class AgentSessionInfo(Session):
# TODO: is this used anywhere?
vector_db_id: str | None = None
started_at: datetime
access_attributes: AccessAttributes | None = None
owner: User | None = None
identifier: str | None = None
type: str = "session"
class AgentInfo(AgentConfig):
@ -30,24 +33,27 @@ class AgentInfo(AgentConfig):
class AgentPersistence:
def __init__(self, agent_id: str, kvstore: KVStore):
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
self.agent_id = agent_id
self.kvstore = kvstore
self.policy = policy
async def create_session(self, name: str) -> str:
session_id = str(uuid.uuid4())
# Get current user's auth attributes for new sessions
auth_attributes = get_auth_attributes()
access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None
user = get_authenticated_user()
session_info = AgentSessionInfo(
session_id=session_id,
session_name=name,
started_at=datetime.now(timezone.utc),
access_attributes=access_attributes,
owner=user,
turns=[],
identifier=name, # should this be qualified in any way?
)
if not is_action_allowed(self.policy, "create", session_info, user):
raise AccessDeniedError()
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
@ -73,10 +79,10 @@ class AgentPersistence:
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
"""Check if current user has access to the session."""
# Handle backward compatibility for old sessions without access control
if not hasattr(session_info, "access_attributes"):
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
return True
return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
"""Get session info if the user has access to it. For internal use by sub-session methods."""

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.distribution.datatypes import Api
from .config import LocalfsFilesImplConfig
from .files import LocalfsFilesImpl
__all__ = ["LocalfsFilesImpl", "LocalfsFilesImplConfig"]
async def get_provider_impl(config: LocalfsFilesImplConfig, deps: dict[Api, Any]):
impl = LocalfsFilesImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
class LocalfsFilesImplConfig(BaseModel):
storage_dir: str = Field(
description="Directory to store uploaded files",
)
metadata_store: SqlStoreConfig = Field(
description="SQL store configuration for file metadata",
)
ttl_secs: int = 365 * 24 * 60 * 60 # 1 year
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"storage_dir": "${env.FILES_STORAGE_DIR:" + __distro_dir__ + "/files}",
"metadata_store": SqliteSqlStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="files_metadata.db",
),
}

View file

@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
import uuid
from pathlib import Path
from typing import Annotated
from fastapi import File, Form, Response, UploadFile
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import (
Files,
ListOpenAIFileResponse,
OpenAIFileDeleteResponse,
OpenAIFileObject,
OpenAIFilePurpose,
)
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl
from .config import LocalfsFilesImplConfig
class LocalfsFilesImpl(Files):
def __init__(self, config: LocalfsFilesImplConfig) -> None:
self.config = config
self.sql_store: SqlStore | None = None
async def initialize(self) -> None:
"""Initialize the files provider by setting up storage directory and metadata database."""
# Create storage directory if it doesn't exist
storage_path = Path(self.config.storage_dir)
storage_path.mkdir(parents=True, exist_ok=True)
# Initialize SQL store for metadata
self.sql_store = sqlstore_impl(self.config.metadata_store)
await self.sql_store.create_table(
"openai_files",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"filename": ColumnType.STRING,
"purpose": ColumnType.STRING,
"bytes": ColumnType.INTEGER,
"created_at": ColumnType.INTEGER,
"expires_at": ColumnType.INTEGER,
"file_path": ColumnType.STRING, # Path to actual file on disk
},
)
def _generate_file_id(self) -> str:
"""Generate a unique file ID for OpenAI API."""
return f"file-{uuid.uuid4().hex}"
def _get_file_path(self, file_id: str) -> Path:
"""Get the filesystem path for a file ID."""
return Path(self.config.storage_dir) / file_id
# OpenAI Files API Implementation
async def openai_upload_file(
self,
file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()],
) -> OpenAIFileObject:
"""Upload a file that can be used across various endpoints."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
file_id = self._generate_file_id()
file_path = self._get_file_path(file_id)
content = await file.read()
file_size = len(content)
with open(file_path, "wb") as f:
f.write(content)
created_at = int(time.time())
expires_at = created_at + self.config.ttl_secs
await self.sql_store.insert(
"openai_files",
{
"id": file_id,
"filename": file.filename or "uploaded_file",
"purpose": purpose.value,
"bytes": file_size,
"created_at": created_at,
"expires_at": expires_at,
"file_path": file_path.as_posix(),
},
)
return OpenAIFileObject(
id=file_id,
filename=file.filename or "uploaded_file",
purpose=purpose,
bytes=file_size,
created_at=created_at,
expires_at=expires_at,
)
async def openai_list_files(
self,
after: str | None = None,
limit: int | None = 10000,
order: Order | None = Order.desc,
purpose: OpenAIFilePurpose | None = None,
) -> ListOpenAIFileResponse:
"""Returns a list of files that belong to the user's organization."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
# TODO: Implement 'after' pagination properly
if after:
raise NotImplementedError("After pagination not yet implemented")
where = None
if purpose:
where = {"purpose": purpose.value}
rows = await self.sql_store.fetch_all(
"openai_files",
where=where,
order_by=[("created_at", order.value if order else Order.desc.value)],
limit=limit,
)
files = [
OpenAIFileObject(
id=row["id"],
filename=row["filename"],
purpose=OpenAIFilePurpose(row["purpose"]),
bytes=row["bytes"],
created_at=row["created_at"],
expires_at=row["expires_at"],
)
for row in rows
]
return ListOpenAIFileResponse(
data=files,
has_more=False, # TODO: Implement proper pagination
first_id=files[0].id if files else "",
last_id=files[-1].id if files else "",
)
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
"""Returns information about a specific file."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ValueError(f"File with id {file_id} not found")
return OpenAIFileObject(
id=row["id"],
filename=row["filename"],
purpose=OpenAIFilePurpose(row["purpose"]),
bytes=row["bytes"],
created_at=row["created_at"],
expires_at=row["expires_at"],
)
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
"""Delete a file."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ValueError(f"File with id {file_id} not found")
# Delete physical file
file_path = Path(row["file_path"])
if file_path.exists():
file_path.unlink()
# Delete metadata from database
await self.sql_store.delete("openai_files", where={"id": file_id})
return OpenAIFileDeleteResponse(
id=file_id,
deleted=True,
)
async def openai_retrieve_file_content(self, file_id: str) -> Response:
"""Returns the contents of the specified file."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
# Get file metadata
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ValueError(f"File with id {file_id} not found")
# Read file content
file_path = Path(row["file_path"])
if not file_path.exists():
raise ValueError(f"File content not found on disk: {file_path}")
with open(file_path, "rb") as f:
content = f.read()
# Return as binary response with appropriate content type
return Response(
content=content,
media_type="application/octet-stream",
headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
)
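Not part of the diff: a stdlib mimic of the storage layout this provider uses, where uploaded content lands at <storage_dir>/<file_id> and a metadata row keeps the OpenAI-style fields returned by openai_list_files / openai_retrieve_file. The directory, filename, and purpose values below are illustrative.
import time
import uuid
from pathlib import Path

storage_dir = Path("/tmp/llama_stack_files_demo")
storage_dir.mkdir(parents=True, exist_ok=True)

content = b"hello, files API"
file_id = f"file-{uuid.uuid4().hex}"      # same id scheme as _generate_file_id
file_path = storage_dir / file_id
file_path.write_bytes(content)

created_at = int(time.time())
metadata_row = {
    "id": file_id,
    "filename": "hello.txt",
    "purpose": "assistants",
    "bytes": len(content),
    "created_at": created_at,
    "expires_at": created_at + 365 * 24 * 60 * 60,  # matches the 1-year default ttl_secs
    "file_path": file_path.as_posix(),
}
print(metadata_row)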

View file

@ -40,6 +40,7 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -410,6 +411,16 @@ class VLLMInferenceImpl(
) -> EmbeddingsResponse:
raise NotImplementedError()
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def chat_completion(
self,
model_id: str,

View file

@ -30,7 +30,7 @@ class TelemetryConfig(BaseModel):
)
service_name: str = Field(
# service name is always the same, use zero-width space to avoid clutter
default="",
default="\u200b",
description="The service name to use for telemetry",
)
sinks: list[TelemetrySink] = Field(
@ -52,7 +52,7 @@ class TelemetryConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> dict[str, Any]:
return {
"service_name": "${env.OTEL_SERVICE_NAME:}",
"service_name": "${env.OTEL_SERVICE_NAME:\u200b}",
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
"sqlite_db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
}

View file

@ -146,7 +146,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
]
for i, chunk in enumerate(chunks):
metadata = chunk.metadata
tokens += metadata["token_count"]
tokens += metadata.get("token_count", 0)
tokens += metadata.get("metadata_token_count", 0)
if tokens > query_config.max_tokens_in_context:

View file

@ -24,7 +24,7 @@ def available_providers() -> list[ProviderSpec]:
"pandas",
"scikit-learn",
]
+ kvstore_dependencies(),
+ kvstore_dependencies(), # TODO make this dynamic based on the kvstore config
module="llama_stack.providers.inline.agents.meta_reference",
config_class="llama_stack.providers.inline.agents.meta_reference.MetaReferenceAgentsImplConfig",
api_dependencies=[

View file

@ -4,8 +4,22 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.datatypes import ProviderSpec
from llama_stack.providers.datatypes import (
Api,
InlineProviderSpec,
ProviderSpec,
)
from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
def available_providers() -> list[ProviderSpec]:
return []
return [
InlineProviderSpec(
api=Api.files,
provider_type="inline::localfs",
# TODO: make this dynamic according to the sql store type
pip_packages=sql_store_pip_packages,
module="llama_stack.providers.inline.files.localfs",
config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
),
]

View file

@ -15,7 +15,6 @@ from llama_stack.providers.datatypes import (
META_REFERENCE_DEPS = [
"accelerate",
"blobfile",
"fairscale",
"torch",
"torchvision",

View file

@ -20,7 +20,6 @@ def available_providers() -> list[ProviderSpec]:
api=Api.tool_runtime,
provider_type="inline::rag-runtime",
pip_packages=[
"blobfile",
"chardet",
"pypdf",
"tqdm",

View file

@ -22,6 +22,7 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -197,3 +198,13 @@ class BedrockInferenceAdapter(
response_body = json.loads(response.get("body").read())
embeddings.append(response_body.get("embedding"))
return EmbeddingsResponse(embeddings=embeddings)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()

View file

@ -21,6 +21,7 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -194,3 +195,13 @@ class CerebrasInferenceAdapter(
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()

View file

@ -20,6 +20,7 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -152,3 +153,13 @@ class DatabricksInferenceAdapter(
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()

View file

@ -37,6 +37,7 @@ from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
@ -254,7 +255,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
params = {
"model": request.model,
**input_dict,
"stream": request.stream,
"stream": bool(request.stream),
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
}
logger.debug(f"params to fireworks: {params}")
@ -286,6 +287,16 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def openai_completion(
self,
model: str,

View file

@ -29,6 +29,7 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -238,6 +239,16 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
#
return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def chat_completion(
self,
model_id: str,

View file

@ -12,7 +12,7 @@ from llama_stack.providers.utils.inference.model_registry import (
build_model_entry,
)
model_entries = [
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"llama3.1:8b-instruct-fp16",
CoreModelId.llama3_1_8b_instruct.value,

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import uuid
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
@ -32,6 +33,7 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -76,7 +78,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
request_has_media,
)
from .models import model_entries
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
@ -86,7 +88,7 @@ class OllamaInferenceAdapter(
ModelsProtocolPrivate,
):
def __init__(self, url: str) -> None:
self.register_helper = ModelRegistryHelper(model_entries)
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
self.url = url
@property
@ -343,21 +345,27 @@ class OllamaInferenceAdapter(
model = await self.register_helper.register_model(model)
except ValueError:
pass # Ignore statically unknown model, will check live listing
if model.provider_resource_id is None:
raise ValueError("Model provider_resource_id cannot be None")
if model.model_type == ModelType.embedding:
logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
await self.client.pull(model.provider_resource_id)
# TODO: you should pull here only if the model is not found in a list
response = await self.client.list()
if model.provider_resource_id not in [m.model for m in response.models]:
await self.client.pull(model.provider_resource_id)
# we use list() here instead of ps() -
# - ps() only lists running models, not available models
# - models not currently running are run by the ollama server as needed
response = await self.client.list()
available_models = [m["model"] for m in response["models"]]
if model.provider_resource_id is None:
raise ValueError("Model provider_resource_id cannot be None")
available_models = [m.model for m in response.models]
provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
if provider_resource_id is None:
provider_resource_id = model.provider_resource_id
if provider_resource_id not in available_models:
available_models_latest = [m["model"].split(":latest")[0] for m in response["models"]]
available_models_latest = [m.model.split(":latest")[0] for m in response.models]
if provider_resource_id in available_models_latest:
logger.warning(
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
@ -370,6 +378,16 @@ class OllamaInferenceAdapter(
return model
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def openai_completion(
self,
model: str,
@ -469,7 +487,25 @@ class OllamaInferenceAdapter(
top_p=top_p,
user=user,
)
return await self.openai_client.chat.completions.create(**params) # type: ignore
response = await self.openai_client.chat.completions.create(**params)
return await self._adjust_ollama_chat_completion_response_ids(response)
async def _adjust_ollama_chat_completion_response_ids(
self,
response: OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk],
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
id = f"chatcmpl-{uuid.uuid4()}"
if isinstance(response, AsyncIterator):
async def stream_with_chunk_ids() -> AsyncIterator[OpenAIChatCompletionChunk]:
async for chunk in response:
chunk.id = id
yield chunk
return stream_with_chunk_ids()
else:
response.id = id
return response
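Not part of the diff: a self-contained sketch of the id-rewriting wrapper above, re-yielding streamed chunks with a single synthetic chatcmpl id. The chunk objects here are plain stand-ins, not the OpenAI pydantic chunk types.
import asyncio
import uuid
from collections.abc import AsyncIterator
from types import SimpleNamespace

async def fake_chunks() -> AsyncIterator[SimpleNamespace]:
    for text in ("Hel", "lo"):
        yield SimpleNamespace(id="ollama-raw-id", delta=text)

async def with_stable_id(stream: AsyncIterator[SimpleNamespace]) -> AsyncIterator[SimpleNamespace]:
    new_id = f"chatcmpl-{uuid.uuid4()}"
    async for chunk in stream:
        chunk.id = new_id  # overwrite the provider-assigned id on every chunk
        yield chunk

async def main() -> None:
    async for chunk in with_stable_id(fake_chunks()):
        print(chunk.id, chunk.delta)

asyncio.run(main())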
async def batch_completion(
self,

View file

@ -14,6 +14,9 @@ from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
@ -38,6 +41,7 @@ logger = logging.getLogger(__name__)
# | batch_chat_completion | LiteLLMOpenAIMixin |
# | openai_completion | AsyncOpenAI |
# | openai_chat_completion | AsyncOpenAI |
# | openai_embeddings | AsyncOpenAI |
#
class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: OpenAIConfig) -> None:
@ -171,3 +175,51 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
user=user,
)
return await self._openai_client.chat.completions.create(**params)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
model_id = (await self.model_store.get_model(model)).provider_resource_id
if model_id.startswith("openai/"):
model_id = model_id[len("openai/") :]
# Prepare parameters for OpenAI embeddings API
params = {
"model": model_id,
"input": input,
}
if encoding_format is not None:
params["encoding_format"] = encoding_format
if dimensions is not None:
params["dimensions"] = dimensions
if user is not None:
params["user"] = user
# Call OpenAI embeddings API
response = await self._openai_client.embeddings.create(**params)
data = []
for i, embedding_data in enumerate(response.data):
data.append(
OpenAIEmbeddingData(
embedding=embedding_data.embedding,
index=i,
)
)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
return OpenAIEmbeddingsResponse(
data=data,
model=response.model,
usage=usage,
)
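Not part of the diff: a client-side sketch of what this adapter method forwards, a standard OpenAI-compatible embeddings call. The base_url, api_key, and model name are assumptions for a locally running stack, not values taken from this commit.
import asyncio
from openai import AsyncOpenAI

async def main() -> None:
    # Assumed local endpoint; any OpenAI-compatible server works the same way
    client = AsyncOpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")
    resp = await client.embeddings.create(
        model="openai/text-embedding-3-small",
        input=["hello world"],
        encoding_format="float",
    )
    print(resp.data[0].index, len(resp.data[0].embedding), resp.usage.total_tokens)

asyncio.run(main())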

View file

@ -19,6 +19,7 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -210,6 +211,16 @@ class PassthroughInferenceAdapter(Inference):
task_type=task_type,
)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def openai_completion(
self,
model: str,

View file

@ -8,6 +8,7 @@ from collections.abc import AsyncGenerator
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.inference.inference import OpenAIEmbeddingsResponse
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
@ -134,3 +135,13 @@ class RunpodInferenceAdapter(
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()

View file

@ -218,7 +218,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
"json_schema": {
"name": name,
"schema": fmt,
"strict": True,
"strict": False,
},
}
if request.tools:

View file

@ -23,6 +23,7 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
ResponseFormatType,
SamplingParams,
@ -291,6 +292,16 @@ class _HfAdapter(
) -> EmbeddingsResponse:
raise NotImplementedError()
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
class TGIAdapter(_HfAdapter):
async def initialize(self, config: TGIImplConfig) -> None:

View file

@ -23,6 +23,7 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
ResponseFormatType,
SamplingParams,
@ -267,6 +268,16 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
embeddings = [item.embedding for item in r.data]
return EmbeddingsResponse(embeddings=embeddings)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def openai_completion(
self,
model: str,

View file

@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -507,6 +508,16 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def openai_completion(
self,
model: str,

View file

@ -21,6 +21,7 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -260,6 +261,16 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
) -> EmbeddingsResponse:
raise NotImplementedError("embedding is not supported for watsonx")
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def openai_completion(
self,
model: str,

View file

@ -4,7 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import logging
import struct
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@ -15,6 +17,9 @@ from llama_stack.apis.inference import (
EmbeddingTaskType,
InterleavedContentItem,
ModelStore,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
TextTruncation,
)
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
@ -43,6 +48,50 @@ class SentenceTransformerEmbeddingMixin:
)
return EmbeddingsResponse(embeddings=embeddings)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
# Convert input to list format if it's a single string
input_list = [input] if isinstance(input, str) else input
if not input_list:
raise ValueError("Empty list not supported")
# Get the model and generate embeddings
model_obj = await self.model_store.get_model(model)
embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id)
embeddings = embedding_model.encode(input_list, show_progress_bar=False)
# Convert embeddings to the requested format
data = []
for i, embedding in enumerate(embeddings):
if encoding_format == "base64":
# Convert float array to base64 string
float_bytes = struct.pack(f"{len(embedding)}f", *embedding)
embedding_value = base64.b64encode(float_bytes).decode("ascii")
else:
# Default to float format
embedding_value = embedding.tolist()
data.append(
OpenAIEmbeddingData(
embedding=embedding_value,
index=i,
)
)
# Not returning actual token usage
usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
return OpenAIEmbeddingsResponse(
data=data,
model=model_obj.provider_resource_id,
usage=usage,
)
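Not part of the diff: a stdlib round trip for the base64 encoding used above, packing floats as float32 in native byte order, base64-encoding them for transport, and unpacking them back. The 3-element embedding is made up.
import base64
import struct

embedding = [0.1, 0.2, 0.3]
float_bytes = struct.pack(f"{len(embedding)}f", *embedding)
encoded = base64.b64encode(float_bytes).decode("ascii")

decoded = struct.unpack(f"{len(embedding)}f", base64.b64decode(encoded))
print(encoded)
print([round(v, 4) for v in decoded])  # back to the original values at float32 precision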
def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
global EMBEDDING_MODELS

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import struct
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
@ -35,6 +37,9 @@ from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
@ -264,6 +269,52 @@ class LiteLLMOpenAIMixin(
embeddings = [data["embedding"] for data in response["data"]]
return EmbeddingsResponse(embeddings=embeddings)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
model_obj = await self.model_store.get_model(model)
# Convert input to list if it's a string
input_list = [input] if isinstance(input, str) else input
# Call litellm embedding function
# litellm.drop_params = True
response = litellm.embedding(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
input=input_list,
api_key=self.get_api_key(),
api_base=self.api_base,
dimensions=dimensions,
)
# Convert response to OpenAI format
data = []
for i, embedding_data in enumerate(response["data"]):
# we encode to base64 if the encoding format is base64 in the request
if encoding_format == "base64":
byte_data = b"".join(struct.pack("f", f) for f in embedding_data["embedding"])
embedding = base64.b64encode(byte_data).decode("utf-8")
else:
embedding = embedding_data["embedding"]
data.append(OpenAIEmbeddingData(embedding=embedding, index=i))
usage = OpenAIEmbeddingUsage(
prompt_tokens=response["usage"]["prompt_tokens"],
total_tokens=response["usage"]["total_tokens"],
)
return OpenAIEmbeddingsResponse(
data=data,
model=model_obj.provider_resource_id,
usage=usage,
)
async def openai_completion(
self,
model: str,

View file

@ -36,6 +36,10 @@ class RedisKVStoreConfig(CommonConfig):
def url(self) -> str:
return f"redis://{self.host}:{self.port}"
@property
def pip_packages(self) -> list[str]:
return ["redis"]
@classmethod
def sample_run_config(cls):
return {
@ -53,6 +57,10 @@ class SqliteKVStoreConfig(CommonConfig):
description="File path for the sqlite database",
)
@property
def pip_packages(self) -> list[str]:
return ["aiosqlite"]
@classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "kvstore.db"):
return {
@ -65,22 +73,22 @@ class SqliteKVStoreConfig(CommonConfig):
class PostgresKVStoreConfig(CommonConfig):
type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
host: str = "localhost"
port: int = 5432
port: str = "5432"
db: str = "llamastack"
user: str
password: str | None = None
table_name: str = "llamastack_kvstore"
@classmethod
def sample_run_config(cls, table_name: str = "llamastack_kvstore"):
def sample_run_config(cls, table_name: str = "llamastack_kvstore", **kwargs):
return {
"type": "postgres",
"namespace": None,
"host": "${env.POSTGRES_HOST:localhost}",
"port": "${env.POSTGRES_PORT:5432}",
"db": "${env.POSTGRES_DB}",
"user": "${env.POSTGRES_USER}",
"password": "${env.POSTGRES_PASSWORD}",
"db": "${env.POSTGRES_DB:llamastack}",
"user": "${env.POSTGRES_USER:llamastack}",
"password": "${env.POSTGRES_PASSWORD:llamastack}",
"table_name": "${env.POSTGRES_TABLE_NAME:" + table_name + "}",
}
@ -100,6 +108,10 @@ class PostgresKVStoreConfig(CommonConfig):
raise ValueError("Table name must be less than 63 characters")
return v
@property
def pip_packages(self) -> list[str]:
return ["psycopg2-binary"]
class MongoDBKVStoreConfig(CommonConfig):
type: Literal[KVStoreType.mongodb.value] = KVStoreType.mongodb.value
@ -110,6 +122,10 @@ class MongoDBKVStoreConfig(CommonConfig):
password: str | None = None
collection_name: str = "llamastack_kvstore"
@property
def pip_packages(self) -> list[str]:
return ["pymongo"]
@classmethod
def sample_run_config(cls, collection_name: str = "llamastack_kvstore"):
return {

View file

@ -10,6 +10,13 @@ from .config import KVStoreConfig, KVStoreType
def kvstore_dependencies():
"""
Returns all possible kvstore dependencies for registry/provider specifications.
NOTE: For specific kvstore implementations, use config.pip_packages instead.
This function returns the union of all dependencies for cases where the specific
kvstore type is not known at declaration time (e.g., provider registries).
"""
return ["aiosqlite", "psycopg2-binary", "redis", "pymongo"]

View file

@ -171,6 +171,22 @@ def make_overlapped_chunks(
return chunks
def _validate_embedding(embedding: NDArray, index: int, expected_dimension: int):
"""Helper method to validate embedding format and dimensions"""
if not isinstance(embedding, (list | np.ndarray)):
raise ValueError(f"Embedding at index {index} must be a list or numpy array, got {type(embedding)}")
if isinstance(embedding, np.ndarray):
if not np.issubdtype(embedding.dtype, np.number):
raise ValueError(f"Embedding at index {index} contains non-numeric values")
else:
if not all(isinstance(e, (float | int | np.number)) for e in embedding):
raise ValueError(f"Embedding at index {index} contains non-numeric values")
if len(embedding) != expected_dimension:
raise ValueError(f"Embedding at index {index} has dimension {len(embedding)}, expected {expected_dimension}")
class EmbeddingIndex(ABC):
@abstractmethod
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
@ -199,11 +215,22 @@ class VectorDBWithIndex:
self,
chunks: list[Chunk],
) -> None:
embeddings_response = await self.inference_api.embeddings(
self.vector_db.embedding_model, [x.content for x in chunks]
)
embeddings = np.array(embeddings_response.embeddings)
chunks_to_embed = []
for i, c in enumerate(chunks):
if c.embedding is None:
chunks_to_embed.append(c)
else:
_validate_embedding(c.embedding, i, self.vector_db.embedding_dimension)
if chunks_to_embed:
resp = await self.inference_api.embeddings(
self.vector_db.embedding_model,
[c.content for c in chunks_to_embed],
)
for c, embedding in zip(chunks_to_embed, resp.embeddings, strict=False):
c.embedding = embedding
embeddings = np.array([c.embedding for c in chunks], dtype=np.float32)
await self.index.add_chunks(chunks, embeddings)
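Not part of the diff: a plain-dict mimic of the new insert path, validating chunks that already carry an embedding against the expected dimension and leaving the rest to be embedded via the inference API. The chunk shape and dimension are illustrative.
import numpy as np

EXPECTED_DIM = 4

chunks = [
    {"content": "precomputed", "embedding": [0.1, 0.2, 0.3, 0.4]},
    {"content": "needs embedding", "embedding": None},
]

for i, c in enumerate(chunks):
    emb = c["embedding"]
    if emb is None:
        continue  # would be sent to inference_api.embeddings(...) for embedding
    arr = np.asarray(emb)
    if not np.issubdtype(arr.dtype, np.number):
        raise ValueError(f"Embedding at index {i} contains non-numeric values")
    if len(arr) != EXPECTED_DIM:
        raise ValueError(f"Embedding at index {i} has dimension {len(arr)}, expected {EXPECTED_DIM}")

to_embed = [c for c in chunks if c["embedding"] is None]
print(f"{len(to_embed)} chunk(s) still need embeddings")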
async def query_chunks(

View file

@ -19,10 +19,10 @@ from sqlalchemy import (
Text,
select,
)
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from ..api import ColumnDefinition, ColumnType, SqlStore
from ..sqlstore import SqliteSqlStoreConfig
from .api import ColumnDefinition, ColumnType, SqlStore
from .sqlstore import SqlAlchemySqlStoreConfig
TYPE_MAPPING: dict[ColumnType, Any] = {
ColumnType.INTEGER: Integer,
@ -35,9 +35,10 @@ TYPE_MAPPING: dict[ColumnType, Any] = {
}
class SqliteSqlStoreImpl(SqlStore):
def __init__(self, config: SqliteSqlStoreConfig):
self.engine = create_async_engine(config.engine_str)
class SqlAlchemySqlStoreImpl(SqlStore):
def __init__(self, config: SqlAlchemySqlStoreConfig):
self.config = config
self.async_session = async_sessionmaker(create_async_engine(config.engine_str))
self.metadata = MetaData()
async def create_table(
@ -78,13 +79,14 @@ class SqliteSqlStoreImpl(SqlStore):
# Create the table in the database if it doesn't exist
# checkfirst=True ensures it doesn't try to recreate if it's already there
async with self.engine.begin() as conn:
engine = create_async_engine(self.config.engine_str)
async with engine.begin() as conn:
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
async with self.engine.begin() as conn:
await conn.execute(self.metadata.tables[table].insert(), data)
await conn.commit()
async with self.async_session() as session:
await session.execute(self.metadata.tables[table].insert(), data)
await session.commit()
async def fetch_all(
self,
@ -93,7 +95,7 @@ class SqliteSqlStoreImpl(SqlStore):
limit: int | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
) -> list[dict[str, Any]]:
async with self.engine.begin() as conn:
async with self.async_session() as session:
query = select(self.metadata.tables[table])
if where:
for key, value in where.items():
@ -117,7 +119,7 @@ class SqliteSqlStoreImpl(SqlStore):
query = query.order_by(self.metadata.tables[table].c[name].desc())
else:
raise ValueError(f"Invalid order '{order_type}' for column '{name}'")
result = await conn.execute(query)
result = await session.execute(query)
if result.rowcount == 0:
return []
return [dict(row._mapping) for row in result]
@ -142,20 +144,20 @@ class SqliteSqlStoreImpl(SqlStore):
if not where:
raise ValueError("where is required for update")
async with self.engine.begin() as conn:
async with self.async_session() as session:
stmt = self.metadata.tables[table].update()
for key, value in where.items():
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
await conn.execute(stmt, data)
await conn.commit()
await session.execute(stmt, data)
await session.commit()
async def delete(self, table: str, where: Mapping[str, Any]) -> None:
if not where:
raise ValueError("where is required for delete")
async with self.engine.begin() as conn:
async with self.async_session() as session:
stmt = self.metadata.tables[table].delete()
for key, value in where.items():
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
await conn.execute(stmt)
await conn.commit()
await session.execute(stmt)
await session.commit()
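Not part of the diff: a runnable sketch of the async_sessionmaker pattern adopted here, with DDL run on an engine connection and DML/queries run through short-lived sessions. The sqlite file path and table shape are throwaway illustrations.
import asyncio
import uuid
from sqlalchemy import Column, Integer, MetaData, String, Table, select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

async def main() -> None:
    engine = create_async_engine("sqlite+aiosqlite:///./demo_sqlstore.db")
    session_factory = async_sessionmaker(engine)
    metadata = MetaData()
    files = Table("openai_files", metadata, Column("id", String, primary_key=True), Column("bytes", Integer))

    async with engine.begin() as conn:  # create_table(): DDL via the engine
        await conn.run_sync(metadata.create_all, checkfirst=True)

    async with session_factory() as session:  # insert(): DML via a session
        await session.execute(files.insert(), {"id": f"file-{uuid.uuid4().hex}", "bytes": 42})
        await session.commit()

    async with session_factory() as session:  # fetch_all(): queries via a session
        rows = (await session.execute(select(files))).all()
        print([dict(row._mapping) for row in rows])

asyncio.run(main())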

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
from abc import abstractmethod
from enum import Enum
from pathlib import Path
from typing import Annotated, Literal
@ -15,13 +16,26 @@ from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from .api import SqlStore
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
class SqlStoreType(Enum):
sqlite = "sqlite"
postgres = "postgres"
class SqliteSqlStoreConfig(BaseModel):
class SqlAlchemySqlStoreConfig(BaseModel):
@property
@abstractmethod
def engine_str(self) -> str: ...
# TODO: move this when we have a better way to specify dependencies with internal APIs
@property
def pip_packages(self) -> list[str]:
return ["sqlalchemy[asyncio]"]
class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
type: Literal["sqlite"] = SqlStoreType.sqlite.value
db_path: str = Field(
default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
@ -39,18 +53,37 @@ class SqliteSqlStoreConfig(BaseModel):
db_path="${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
)
# TODO: move this when we have a better way to specify dependencies with internal APIs
@property
def pip_packages(self) -> list[str]:
return ["sqlalchemy[asyncio]"]
return super().pip_packages + ["aiosqlite"]
class PostgresSqlStoreConfig(BaseModel):
class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
type: Literal["postgres"] = SqlStoreType.postgres.value
host: str = "localhost"
port: str = "5432"
db: str = "llamastack"
user: str
password: str | None = None
@property
def engine_str(self) -> str:
return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"
@property
def pip_packages(self) -> list[str]:
raise NotImplementedError("Postgres is not implemented yet")
return super().pip_packages + ["asyncpg"]
@classmethod
def sample_run_config(cls, **kwargs):
return cls(
type="postgres",
host="${env.POSTGRES_HOST:localhost}",
port="${env.POSTGRES_PORT:5432}",
db="${env.POSTGRES_DB:llamastack}",
user="${env.POSTGRES_USER:llamastack}",
password="${env.POSTGRES_PASSWORD:llamastack}",
)
SqlStoreConfig = Annotated[
@ -60,12 +93,10 @@ SqlStoreConfig = Annotated[
def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
if config.type == SqlStoreType.sqlite.value:
from .sqlite.sqlite import SqliteSqlStoreImpl
if config.type in [SqlStoreType.sqlite.value, SqlStoreType.postgres.value]:
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
impl = SqliteSqlStoreImpl(config)
elif config.type == SqlStoreType.postgres.value:
raise NotImplementedError("Postgres is not implemented yet")
impl = SqlAlchemySqlStoreImpl(config)
else:
raise ValueError(f"Unknown sqlstore type {config.type}")