Merge branch 'main' into nvidia-e2e-notebook

This commit is contained in:
Jash Gulabrai 2025-05-19 09:23:07 -04:00
commit 51b68b4be6
234 changed files with 21943 additions and 7540 deletions

View file

@ -95,6 +95,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_groups_api: ToolGroups,
vector_io_api: VectorIO,
persistence_store: KVStore,
created_at: str,
):
self.agent_id = agent_id
self.agent_config = agent_config
@ -104,6 +105,7 @@ class ChatAgent(ShieldRunnerMixin):
self.storage = AgentPersistence(agent_id, persistence_store)
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
self.created_at = created_at
ShieldRunnerMixin.__init__(
self,

View file

@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from llama_stack.apis.agents import (
Agent,
@ -20,14 +20,13 @@ from llama_stack.apis.agents import (
AgentTurnCreateRequest,
AgentTurnResumeRequest,
Document,
ListAgentSessionsResponse,
ListAgentsResponse,
OpenAIResponseInputMessage,
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseObject,
Session,
Turn,
)
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import (
Inference,
ToolConfig,
@ -39,13 +38,14 @@ from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from .agent_instance import ChatAgent
from .config import MetaReferenceAgentsImplConfig
from .openai_responses import OpenAIResponsesImpl
from .persistence import AgentInfo
logger = logging.getLogger()
logger.setLevel(logging.INFO)
class MetaReferenceAgentsImpl(Agents):
@ -82,43 +82,47 @@ class MetaReferenceAgentsImpl(Agents):
agent_config: AgentConfig,
) -> AgentCreateResponse:
agent_id = str(uuid.uuid4())
created_at = datetime.now(timezone.utc)
agent_info = AgentInfo(
**agent_config.model_dump(),
created_at=created_at,
)
# Store the agent info
await self.persistence_store.set(
key=f"agent:{agent_id}",
value=agent_config.model_dump_json(),
value=agent_info.model_dump_json(),
)
return AgentCreateResponse(
agent_id=agent_id,
)
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
agent_config = await self.persistence_store.get(
agent_info_json = await self.persistence_store.get(
key=f"agent:{agent_id}",
)
if not agent_config:
raise ValueError(f"Could not find agent config for {agent_id}")
if not agent_info_json:
raise ValueError(f"Could not find agent info for {agent_id}")
try:
agent_config = json.loads(agent_config)
except json.JSONDecodeError as e:
raise ValueError(f"Could not JSON decode agent config for {agent_id}") from e
try:
agent_config = AgentConfig(**agent_config)
agent_info = AgentInfo.model_validate_json(agent_info_json)
except Exception as e:
raise ValueError(f"Could not validate(?) agent config for {agent_id}") from e
raise ValueError(f"Could not validate agent info for {agent_id}") from e
return ChatAgent(
agent_id=agent_id,
agent_config=agent_config,
agent_config=agent_info,
inference_api=self.inference_api,
safety_api=self.safety_api,
vector_io_api=self.vector_io_api,
tool_runtime_api=self.tool_runtime_api,
tool_groups_api=self.tool_groups_api,
persistence_store=(
self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
),
created_at=agent_info.created_at,
)
async def create_agent_session(
@ -212,6 +216,7 @@ class MetaReferenceAgentsImpl(Agents):
turn_ids: list[str] | None = None,
) -> Session:
agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
@ -226,24 +231,75 @@ class MetaReferenceAgentsImpl(Agents):
)
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
await self.persistence_store.delete(f"session:{agent_id}:{session_id}")
agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
# Delete turns first, then the session
await agent.storage.delete_session_turns(session_id)
await agent.storage.delete_session(session_id)
async def delete_agent(self, agent_id: str) -> None:
# First get all sessions for this agent
agent = await self._get_agent_impl(agent_id)
sessions = await agent.storage.list_sessions()
# Delete all sessions
for session in sessions:
await self.delete_agents_session(agent_id, session.session_id)
# Finally delete the agent itself
await self.persistence_store.delete(f"agent:{agent_id}")
async def shutdown(self) -> None:
pass
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
agent_keys = await self.persistence_store.keys_in_range("agent:", "agent:\xff")
agent_list: list[Agent] = []
for agent_key in agent_keys:
agent_id = agent_key.split(":")[1]
async def list_agents(self) -> ListAgentsResponse:
pass
# Get the agent info using the key
agent_info_json = await self.persistence_store.get(agent_key)
if not agent_info_json:
logger.error(f"Could not find agent info for key {agent_key}")
continue
try:
agent_info = AgentInfo.model_validate_json(agent_info_json)
agent_list.append(
Agent(
agent_id=agent_id,
agent_config=agent_info,
created_at=agent_info.created_at,
)
)
except Exception as e:
logger.error(f"Error parsing agent info for {agent_id}: {e}")
continue
# Convert Agent objects to dictionaries
agent_dicts = [agent.model_dump() for agent in agent_list]
return paginate_records(agent_dicts, start_index, limit)
async def get_agent(self, agent_id: str) -> Agent:
pass
chat_agent = await self._get_agent_impl(agent_id)
agent = Agent(
agent_id=agent_id,
agent_config=chat_agent.agent_config,
created_at=chat_agent.created_at,
)
return agent
async def list_agent_sessions(
self,
agent_id: str,
) -> ListAgentSessionsResponse:
self, agent_id: str, start_index: int | None = None, limit: int | None = None
) -> PaginatedResponse:
agent = await self._get_agent_impl(agent_id)
sessions = await agent.storage.list_sessions()
# Convert Session objects to dictionaries
session_dicts = [session.model_dump() for session in sessions]
return paginate_records(session_dicts, start_index, limit)
async def shutdown(self) -> None:
pass
# OpenAI responses
@ -255,7 +311,7 @@ class MetaReferenceAgentsImpl(Agents):
async def create_openai_response(
self,
input: str | list[OpenAIResponseInputMessage],
input: str | list[OpenAIResponseInput],
model: str,
previous_response_id: str | None = None,
store: bool | None = True,

View file

@ -7,22 +7,29 @@
import json
import uuid
from collections.abc import AsyncIterator
from typing import cast
from typing import Any, cast
from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputMessage,
OpenAIResponseInput,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputItemList,
OpenAIResponseInputMessageContent,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseInputToolFunction,
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseOutput,
OpenAIResponseOutputMessage,
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.inference.inference import (
@ -32,10 +39,13 @@ from llama_stack.apis.inference.inference import (
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIDeveloperMessageParam,
OpenAIImageURL,
OpenAIMessageParam,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
@ -50,31 +60,110 @@ logger = get_logger(name=__name__, category="openai_responses")
OPENAI_RESPONSES_PREFIX = "openai_responses:"
async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> list[OpenAIMessageParam]:
async def _convert_response_content_to_chat_content(
content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent],
) -> str | list[OpenAIChatCompletionContentPartParam]:
"""
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
The content schemas of each API look similar, but are not exactly the same.
"""
if isinstance(content, str):
return content
converted_parts = []
for content_part in content:
if isinstance(content_part, OpenAIResponseInputMessageContentText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
if content_part.image_url:
image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
elif isinstance(content_part, str):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
else:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context"
)
return converted_parts
async def _convert_response_input_to_chat_messages(
input: str | list[OpenAIResponseInput],
) -> list[OpenAIMessageParam]:
"""
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
"""
messages: list[OpenAIMessageParam] = []
for output_message in previous_response.output:
if isinstance(output_message, OpenAIResponseOutputMessage):
messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
if isinstance(input, list):
for input_item in input:
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
messages.append(
OpenAIToolMessageParam(
content=input_item.output,
tool_call_id=input_item.call_id,
)
)
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id=input_item.call_id,
function=OpenAIChatCompletionToolCallFunction(
name=input_item.name,
arguments=input_item.arguments,
),
)
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
else:
content = await _convert_response_content_to_chat_content(input_item.content)
message_type = await _get_message_type_by_role(input_item.role)
if message_type is None:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
)
messages.append(message_type(content=content))
else:
messages.append(OpenAIUserMessageParam(content=input))
return messages
async def _openai_choices_to_output_messages(choices: list[OpenAIChoice]) -> list[OpenAIResponseOutputMessage]:
output_messages = []
for choice in choices:
output_content = ""
if isinstance(choice.message.content, str):
output_content = choice.message.content
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
output_content = choice.message.content.text
# TODO: handle image content
output_messages.append(
OpenAIResponseOutputMessage(
id=f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
status="completed",
)
async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
"""
Convert an OpenAI Chat Completion choice into an OpenAI Response output message.
"""
output_content = ""
if isinstance(choice.message.content, str):
output_content = choice.message.content
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
output_content = choice.message.content.text
else:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
)
return output_messages
return OpenAIResponseMessage(
id=f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
status="completed",
role="assistant",
)
async def _get_message_type_by_role(role: str):
role_to_type = {
"user": OpenAIUserMessageParam,
"system": OpenAISystemMessageParam,
"assistant": OpenAIAssistantMessageParam,
"developer": OpenAIDeveloperMessageParam,
}
return role_to_type.get(role)
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
input_items: OpenAIResponseInputItemList
response: OpenAIResponseObject
class OpenAIResponsesImpl:
@ -90,19 +179,45 @@ class OpenAIResponsesImpl:
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
async def get_openai_response(
self,
id: str,
) -> OpenAIResponseObject:
async def _get_previous_response_with_input(self, id: str) -> OpenAIResponsePreviousResponseWithInputItems:
key = f"{OPENAI_RESPONSES_PREFIX}{id}"
response_json = await self.persistence_store.get(key=key)
if response_json is None:
raise ValueError(f"OpenAI response with id '{id}' not found")
return OpenAIResponseObject.model_validate_json(response_json)
return OpenAIResponsePreviousResponseWithInputItems.model_validate_json(response_json)
async def _prepend_previous_response(
self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None
):
if previous_response_id:
previous_response_with_input = await self._get_previous_response_with_input(previous_response_id)
# previous response input items
new_input_items = previous_response_with_input.input_items.data
# previous response output items
new_input_items.extend(previous_response_with_input.response.output)
# new input items from the current request
if isinstance(input, str):
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
else:
new_input_items.extend(input)
input = new_input_items
return input
async def get_openai_response(
self,
id: str,
) -> OpenAIResponseObject:
response_with_input = await self._get_previous_response_with_input(id)
return response_with_input.response
async def create_openai_response(
self,
input: str | list[OpenAIResponseInputMessage],
input: str | list[OpenAIResponseInput],
model: str,
previous_response_id: str | None = None,
store: bool | None = True,
@ -112,31 +227,8 @@ class OpenAIResponsesImpl:
):
stream = False if stream is None else stream
messages: list[OpenAIMessageParam] = []
if previous_response_id:
previous_response = await self.get_openai_response(previous_response_id)
messages.extend(await _previous_response_to_messages(previous_response))
# TODO: refactor this user_content parsing out into a separate method
user_content: str | list[OpenAIChatCompletionContentPartParam] = ""
if isinstance(input, list):
user_content = []
for user_input in input:
if isinstance(user_input.content, list):
for user_input_content in user_input.content:
if isinstance(user_input_content, OpenAIResponseInputMessageContentText):
user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input_content.text))
elif isinstance(user_input_content, OpenAIResponseInputMessageContentImage):
if user_input_content.image_url:
image_url = OpenAIImageURL(
url=user_input_content.image_url, detail=user_input_content.detail
)
user_content.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
else:
user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input.content))
else:
user_content = input
messages.append(OpenAIUserMessageParam(content=user_content))
input = await self._prepend_previous_response(input, previous_response_id)
messages = await _convert_response_input_to_chat_messages(input)
chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
chat_response = await self.inference_api.openai_chat_completion(
model=model,
@ -150,6 +242,7 @@ class OpenAIResponsesImpl:
# TODO: refactor this into a separate method that handles streaming
chat_response_id = ""
chat_response_content = []
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
# TODO: these chunk_ fields are hacky and only take the last chunk into account
chunk_created = 0
chunk_model = ""
@ -163,7 +256,30 @@ class OpenAIResponsesImpl:
chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
assistant_message = OpenAIAssistantMessageParam(content="".join(chat_response_content))
# Aggregate tool call arguments across chunks, using their index as the aggregation key
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
if response_tool_call:
response_tool_call.function.arguments += tool_call.function.arguments
else:
tool_call_dict: dict[str, Any] = tool_call.model_dump()
# Ensure we don't have any empty type field in the tool call dict.
# The OpenAI client used by providers often returns a type=None here.
tool_call_dict.pop("type", None)
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
chat_response_tool_calls[tool_call.index] = response_tool_call
# Convert the dict of tool calls by index to a list of tool calls to pass back in our response
if chat_response_tool_calls:
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
else:
tool_calls = None
assistant_message = OpenAIAssistantMessageParam(
content="".join(chat_response_content),
tool_calls=tool_calls,
)
chat_response = OpenAIChatCompletion(
id=chat_response_id,
choices=[
@ -181,12 +297,26 @@ class OpenAIResponsesImpl:
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
output_messages: list[OpenAIResponseOutput] = []
if chat_response.choices[0].message.tool_calls:
output_messages.extend(
await self._execute_tool_and_return_final_output(model, stream, chat_response, messages, temperature)
)
else:
output_messages.extend(await _openai_choices_to_output_messages(chat_response.choices))
for choice in chat_response.choices:
if choice.message.tool_calls and tools:
# Assume if the first tool is a function, all tools are functions
if isinstance(tools[0], OpenAIResponseInputToolFunction):
for tool_call in choice.message.tool_calls:
output_messages.append(
OpenAIResponseOutputMessageFunctionToolCall(
arguments=tool_call.function.arguments or "",
call_id=tool_call.id,
name=tool_call.function.name or "",
id=f"fc_{uuid.uuid4()}",
status="completed",
)
)
else:
output_messages.extend(
await self._execute_tool_and_return_final_output(model, stream, choice, messages, temperature)
)
else:
output_messages.append(await _convert_chat_choice_to_response_message(choice))
response = OpenAIResponseObject(
created_at=chat_response.created,
id=f"resp-{uuid.uuid4()}",
@ -195,13 +325,43 @@ class OpenAIResponsesImpl:
status="completed",
output=output_messages,
)
logger.debug(f"OpenAI Responses response: {response}")
if store:
# Store in kvstore
new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(input, str):
# synthesize a message from the input string
input_content = OpenAIResponseInputMessageContentText(text=input)
input_content_item = OpenAIResponseMessage(
role="user",
content=[input_content],
id=new_input_id,
)
input_items_data = [input_content_item]
else:
# we already have a list of messages
input_items_data = []
for input_item in input:
if isinstance(input_item, OpenAIResponseMessage):
# These may or may not already have an id, so dump to dict, check for id, and add if missing
input_item_dict = input_item.model_dump()
if "id" not in input_item_dict:
input_item_dict["id"] = new_input_id
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
else:
input_items_data.append(input_item)
input_items = OpenAIResponseInputItemList(data=input_items_data)
prev_response = OpenAIResponsePreviousResponseWithInputItems(
input_items=input_items,
response=response,
)
key = f"{OPENAI_RESPONSES_PREFIX}{response.id}"
await self.persistence_store.set(
key=key,
value=response.model_dump_json(),
value=prev_response.model_dump_json(),
)
if stream:
@ -221,7 +381,9 @@ class OpenAIResponsesImpl:
chat_tools: list[ChatCompletionToolParam] = []
for input_tool in tools:
# TODO: Handle other tool types
if input_tool.type == "web_search":
if input_tool.type == "function":
chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
elif input_tool.type == "web_search":
tool_name = "web_search"
tool = await self.tool_groups_api.get_tool(tool_name)
tool_def = ToolDefinition(
@ -247,12 +409,11 @@ class OpenAIResponsesImpl:
self,
model_id: str,
stream: bool,
chat_response: OpenAIChatCompletion,
choice: OpenAIChoice,
messages: list[OpenAIMessageParam],
temperature: float,
) -> list[OpenAIResponseOutput]:
output_messages: list[OpenAIResponseOutput] = []
choice = chat_response.choices[0]
# If the choice is not an assistant message, we don't need to execute any tools
if not isinstance(choice.message, OpenAIAssistantMessageParam):
@ -262,6 +423,9 @@ class OpenAIResponsesImpl:
if not choice.message.tool_calls:
return output_messages
# Copy the messages list to avoid mutating the original list
messages = messages.copy()
# Add the assistant message with tool_calls response to the messages list
messages.append(choice.message)
@ -307,7 +471,9 @@ class OpenAIResponsesImpl:
)
# type cast to appease mypy
tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
tool_final_outputs = await _openai_choices_to_output_messages(tool_results_chat_response.choices)
tool_final_outputs = [
await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices
]
# TODO: Wire in annotations with URLs, titles, etc to these output messages
output_messages.extend(tool_final_outputs)
return output_messages

View file

@ -9,9 +9,7 @@ import logging
import uuid
from datetime import datetime, timezone
from pydantic import BaseModel
from llama_stack.apis.agents import ToolExecutionStep, Turn
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.request_headers import get_auth_attributes
@ -20,15 +18,17 @@ from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
class AgentSessionInfo(BaseModel):
session_id: str
session_name: str
class AgentSessionInfo(Session):
# TODO: is this used anywhere?
vector_db_id: str | None = None
started_at: datetime
access_attributes: AccessAttributes | None = None
class AgentInfo(AgentConfig):
created_at: datetime
class AgentPersistence:
def __init__(self, agent_id: str, kvstore: KVStore):
self.agent_id = agent_id
@ -46,6 +46,7 @@ class AgentPersistence:
session_name=name,
started_at=datetime.now(timezone.utc),
access_attributes=access_attributes,
turns=[],
)
await self.kvstore.set(
@ -109,7 +110,7 @@ class AgentPersistence:
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
values = await self.kvstore.range(
values = await self.kvstore.values_in_range(
start_key=f"session:{self.agent_id}:{session_id}:",
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
)
@ -121,7 +122,6 @@ class AgentPersistence:
except Exception as e:
log.error(f"Error parsing turn: {e}")
continue
turns.sort(key=lambda x: (x.completed_at or datetime.min))
return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
@ -170,3 +170,43 @@ class AgentPersistence:
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
)
return int(value) if value else None
async def list_sessions(self) -> list[Session]:
values = await self.kvstore.values_in_range(
start_key=f"session:{self.agent_id}:",
end_key=f"session:{self.agent_id}:\xff\xff\xff\xff",
)
sessions = []
for value in values:
try:
session_info = Session(**json.loads(value))
sessions.append(session_info)
except Exception as e:
log.error(f"Error parsing session info: {e}")
continue
return sessions
async def delete_session_turns(self, session_id: str) -> None:
"""Delete all turns and their associated data for a session.
Args:
session_id: The ID of the session whose turns should be deleted.
"""
turns = await self.get_session_turns(session_id)
for turn in turns:
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}")
async def delete_session(self, session_id: str) -> None:
"""Delete a session and all its associated turns.
Args:
session_id: The ID of the session to delete.
Raises:
ValueError: If the session does not exist.
"""
session_info = await self.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}")

View file

@ -11,9 +11,9 @@ from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.pagination import paginate_records
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from .config import LocalFSDatasetIOConfig
@ -64,7 +64,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
# Load existing datasets from kvstore
start_key = DATASETS_PREFIX
end_key = f"{DATASETS_PREFIX}\xff"
stored_datasets = await self.kvstore.range(start_key, end_key)
stored_datasets = await self.kvstore.values_in_range(start_key, end_key)
for dataset in stored_datasets:
dataset = Dataset.model_validate_json(dataset)

View file

@ -58,7 +58,7 @@ class MetaReferenceEvalImpl(
# Load existing benchmarks from kvstore
start_key = EVAL_TASKS_PREFIX
end_key = f"{EVAL_TASKS_PREFIX}\xff"
stored_benchmarks = await self.kvstore.range(start_key, end_key)
stored_benchmarks = await self.kvstore.values_in_range(start_key, end_key)
for benchmark in stored_benchmarks:
benchmark = Benchmark.model_validate_json(benchmark)

View file

@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
Inference,
InferenceProvider,
InterleavedContent,
LogProbConfig,
Message,
@ -86,7 +86,7 @@ class MetaReferenceInferenceImpl(
OpenAICompletionToLlamaStackMixin,
OpenAIChatCompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
Inference,
InferenceProvider,
ModelsProtocolPrivate,
):
def __init__(self, config: MetaReferenceInferenceConfig) -> None:

View file

@ -9,7 +9,7 @@ from collections.abc import AsyncGenerator
from llama_stack.apis.inference import (
CompletionResponse,
Inference,
InferenceProvider,
InterleavedContent,
LogProbConfig,
Message,
@ -38,7 +38,7 @@ class SentenceTransformersInferenceImpl(
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
Inference,
InferenceProvider,
ModelsProtocolPrivate,
):
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import gc
def evacuate_model_from_device(model, device: str):
"""Safely clear a model from memory and free device resources.
This function handles the proper cleanup of a model by:
1. Moving the model to CPU if it's on a non-CPU device
2. Deleting the model object to free memory
3. Running garbage collection
4. Clearing CUDA cache if the model was on a CUDA device
Args:
model: The PyTorch model to clear
device: The device type the model is currently on ('cuda', 'mps', 'cpu')
Note:
- For CUDA devices, this will clear the CUDA cache after moving the model to CPU
- For MPS devices, only moves the model to CPU (no cache clearing available)
- For CPU devices, only deletes the model object and runs garbage collection
"""
if device != "cpu":
model.to("cpu")
del model
gc.collect()
if device == "cuda":
# we need to import such that this is only imported when the method is called
import torch
torch.cuda.empty_cache()

View file

@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.distribution.datatypes import Api
from .config import HuggingFacePostTrainingConfig
# post_training api and the huggingface provider is still experimental and under heavy development
async def get_provider_impl(
config: HuggingFacePostTrainingConfig,
deps: dict[Api, Any],
):
from .post_training import HuggingFacePostTrainingImpl
impl = HuggingFacePostTrainingImpl(
config,
deps[Api.datasetio],
deps[Api.datasets],
)
return impl

View file

@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Literal
from pydantic import BaseModel
class HuggingFacePostTrainingConfig(BaseModel):
# Device to run training on (cuda, cpu, mps)
device: str = "cuda"
# Distributed training backend if using multiple devices
# fsdp: Fully Sharded Data Parallel
# deepspeed: DeepSpeed ZeRO optimization
distributed_backend: Literal["fsdp", "deepspeed"] | None = None
# Format for saving model checkpoints
# full_state: Save complete model state
# huggingface: Save in HuggingFace format (recommended for compatibility)
checkpoint_format: Literal["full_state", "huggingface"] | None = "huggingface"
# Template for formatting chat inputs and outputs
# Used to structure the conversation format for training
chat_template: str = "<|user|>\n{input}\n<|assistant|>\n{output}"
# Model-specific configuration parameters
# trust_remote_code: Allow execution of custom model code
# attn_implementation: Use SDPA (Scaled Dot Product Attention) for better performance
model_specific_config: dict = {
"trust_remote_code": True,
"attn_implementation": "sdpa",
}
# Maximum sequence length for training
# Set to 2048 as this is the maximum that works reliably on MPS (Apple Silicon)
# Longer sequences may cause memory issues on MPS devices
max_seq_length: int = 2048
# Enable gradient checkpointing to reduce memory usage
# Trades computation for memory by recomputing activations
gradient_checkpointing: bool = False
# Maximum number of checkpoints to keep
# Older checkpoints are deleted when this limit is reached
save_total_limit: int = 3
# Number of training steps between logging updates
logging_steps: int = 10
# Ratio of training steps used for learning rate warmup
# Helps stabilize early training
warmup_ratio: float = 0.1
# L2 regularization coefficient
# Helps prevent overfitting
weight_decay: float = 0.01
# Number of worker processes for data loading
# Higher values can improve data loading speed but increase memory usage
dataloader_num_workers: int = 4
# Whether to pin memory in data loader
# Can improve data transfer speed to GPU but uses more memory
dataloader_pin_memory: bool = True
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}

View file

@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
AlgorithmConfig,
Checkpoint,
DPOAlignmentConfig,
JobStatus,
ListPostTrainingJobsResponse,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
TrainingConfig,
)
from llama_stack.providers.inline.post_training.huggingface.config import (
HuggingFacePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
HFFinetuningSingleDevice,
)
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
from llama_stack.schema_utils import webmethod
class TrainingArtifactType(Enum):
CHECKPOINT = "checkpoint"
RESOURCES_STATS = "resources_stats"
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
class HuggingFacePostTrainingImpl:
def __init__(
self,
config: HuggingFacePostTrainingConfig,
datasetio_api: DatasetIO,
datasets: Datasets,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets
self._scheduler = Scheduler()
async def shutdown(self) -> None:
await self._scheduler.shutdown()
@staticmethod
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.CHECKPOINT.value,
name=checkpoint.identifier,
uri=checkpoint.path,
metadata=dict(checkpoint),
)
@staticmethod
def _resources_stats_to_artifact(resources_stats: dict[str, Any]) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.RESOURCES_STATS.value,
name=TrainingArtifactType.RESOURCES_STATS.value,
metadata=resources_stats,
)
async def supervised_fine_tune(
self,
job_uuid: str,
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
model: str,
checkpoint_dir: str | None = None,
algorithm_config: AlgorithmConfig | None = None,
) -> PostTrainingJob:
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
on_log_message_cb("Starting HF finetuning")
recipe = HFFinetuningSingleDevice(
job_uuid=job_uuid,
datasetio_api=self.datasetio_api,
datasets_api=self.datasets_api,
)
resources_allocated, checkpoints = await recipe.train(
model=model,
output_dir=checkpoint_dir,
job_uuid=job_uuid,
lora_config=algorithm_config,
config=training_config,
provider_config=self.config,
)
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
if checkpoints:
for checkpoint in checkpoints:
artifact = self._checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
on_status_change_cb(SchedulerJobStatus.completed)
on_log_message_cb("HF finetuning completed")
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
return PostTrainingJob(job_uuid=job_uuid)
async def preference_optimize(
self,
job_uuid: str,
finetuned_model: str,
algorithm_config: DPOAlignmentConfig,
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
) -> PostTrainingJob:
raise NotImplementedError("DPO alignment is not implemented yet")
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
)
@staticmethod
def _get_artifacts_metadata_by_type(job, artifact_type):
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
@classmethod
def _get_checkpoints(cls, job):
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
@classmethod
def _get_resources_allocated(cls, job):
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
return data[0] if data else None
@webmethod(route="/post-training/job/status")
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
job = self._scheduler.get_job(job_uuid)
match job.status:
# TODO: Add support for other statuses to API
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
status = JobStatus.scheduled
case SchedulerJobStatus.running:
status = JobStatus.in_progress
case SchedulerJobStatus.completed:
status = JobStatus.completed
case SchedulerJobStatus.failed:
status = JobStatus.failed
case _:
raise NotImplementedError()
return PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=status,
scheduled_at=job.scheduled_at,
started_at=job.started_at,
completed_at=job.completed_at,
checkpoints=self._get_checkpoints(job),
resources_allocated=self._get_resources_allocated(job),
)
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None:
self._scheduler.cancel(job_uuid)
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
job = self._scheduler.get_job(job_uuid)
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))

View file

@ -0,0 +1,683 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import gc
import json
import logging
import multiprocessing
import os
import signal
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import psutil
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
# Set tokenizer parallelism environment variable
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Force PyTorch to use OpenBLAS instead of MKL
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
os.environ["MKL_NUM_THREADS"] = "1"
import torch
from datasets import Dataset
from peft import LoraConfig
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
)
from trl import SFTConfig, SFTTrainer
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
Checkpoint,
DataConfig,
LoraFinetuningConfig,
TrainingConfig,
)
from ..config import HuggingFacePostTrainingConfig
logger = logging.getLogger(__name__)
def get_gb(to_convert: int) -> str:
"""Converts memory stats to GB and formats to 2 decimal places.
Args:
to_convert: Memory value in bytes
Returns:
str: Memory value in GB formatted to 2 decimal places
"""
return f"{(to_convert / (1024**3)):.2f}"
def get_memory_stats(device: torch.device) -> dict[str, Any]:
"""Get memory statistics for the given device."""
stats = {
"system_memory": {
"total": get_gb(psutil.virtual_memory().total),
"available": get_gb(psutil.virtual_memory().available),
"used": get_gb(psutil.virtual_memory().used),
"percent": psutil.virtual_memory().percent,
}
}
if device.type == "cuda":
stats["device_memory"] = {
"allocated": get_gb(torch.cuda.memory_allocated(device)),
"reserved": get_gb(torch.cuda.memory_reserved(device)),
"max_allocated": get_gb(torch.cuda.max_memory_allocated(device)),
}
elif device.type == "mps":
# MPS doesn't provide direct memory stats, but we can track system memory
stats["device_memory"] = {
"note": "MPS memory stats not directly available",
"system_memory_used": get_gb(psutil.virtual_memory().used),
}
elif device.type == "cpu":
# For CPU, we track process memory usage
process = psutil.Process()
stats["device_memory"] = {
"process_rss": get_gb(process.memory_info().rss),
"process_vms": get_gb(process.memory_info().vms),
"process_percent": process.memory_percent(),
}
return stats
def setup_torch_device(device_str: str) -> torch.device:
"""Initialize and validate a PyTorch device.
This function handles device initialization and validation for different device types:
- CUDA: Validates CUDA availability and handles device selection
- MPS: Validates MPS availability for Apple Silicon
- CPU: Basic validation
- HPU: Raises error as it's not supported
Args:
device_str: String specifying the device ('cuda', 'cpu', 'mps')
Returns:
torch.device: The initialized and validated device
Raises:
RuntimeError: If device initialization fails or device is not supported
"""
try:
device = torch.device(device_str)
except RuntimeError as e:
raise RuntimeError(f"Error getting Torch Device {str(e)}") from e
# Validate device capabilities
if device.type == "cuda":
if not torch.cuda.is_available():
raise RuntimeError(
f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device."
)
if device.index is None:
device = torch.device(device.type, torch.cuda.current_device())
elif device.type == "mps":
if not torch.backends.mps.is_available():
raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.")
elif device.type == "hpu":
raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.")
return device
class HFFinetuningSingleDevice:
def __init__(
self,
job_uuid: str,
datasetio_api: DatasetIO,
datasets_api: Datasets,
):
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.job_uuid = job_uuid
def validate_dataset_format(self, rows: list[dict]) -> bool:
"""Validate that the dataset has the required fields."""
required_fields = ["input_query", "expected_answer", "chat_completion_input"]
return all(field in row for row in rows for field in required_fields)
def _process_instruct_format(self, row: dict) -> tuple[str | None, str | None]:
"""Process a row in instruct format."""
if "chat_completion_input" in row and "expected_answer" in row:
try:
messages = json.loads(row["chat_completion_input"])
if not isinstance(messages, list) or len(messages) != 1:
logger.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}")
return None, None
if "content" not in messages[0]:
logger.warning(f"Message missing content: {messages[0]}")
return None, None
return messages[0]["content"], row["expected_answer"]
except json.JSONDecodeError:
logger.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}")
return None, None
return None, None
def _process_dialog_format(self, row: dict) -> tuple[str | None, str | None]:
"""Process a row in dialog format."""
if "dialog" in row:
try:
dialog = json.loads(row["dialog"])
if not isinstance(dialog, list) or len(dialog) < 2:
logger.warning(f"Dialog must have at least 2 messages: {row['dialog']}")
return None, None
if dialog[0].get("role") != "user":
logger.warning(f"First message must be from user: {dialog[0]}")
return None, None
if not any(msg.get("role") == "assistant" for msg in dialog):
logger.warning("Dialog must have at least one assistant message")
return None, None
# Convert to human/gpt format
role_map = {"user": "human", "assistant": "gpt"}
conversations = []
for msg in dialog:
if "role" not in msg or "content" not in msg:
logger.warning(f"Message missing role or content: {msg}")
continue
conversations.append({"from": role_map[msg["role"]], "value": msg["content"]})
# Format as a single conversation
return conversations[0]["value"], conversations[1]["value"]
except json.JSONDecodeError:
logger.warning(f"Failed to parse dialog: {row['dialog']}")
return None, None
return None, None
def _process_fallback_format(self, row: dict) -> tuple[str | None, str | None]:
"""Process a row using fallback formats."""
if "input" in row and "output" in row:
return row["input"], row["output"]
elif "prompt" in row and "completion" in row:
return row["prompt"], row["completion"]
elif "question" in row and "answer" in row:
return row["question"], row["answer"]
return None, None
def _format_text(self, input_text: str, output_text: str, provider_config: HuggingFacePostTrainingConfig) -> str:
"""Format input and output text based on model requirements."""
if hasattr(provider_config, "chat_template"):
return provider_config.chat_template.format(input=input_text, output=output_text)
return f"{input_text}\n{output_text}"
def _create_dataset(
self, rows: list[dict], config: TrainingConfig, provider_config: HuggingFacePostTrainingConfig
) -> Dataset:
"""Create and preprocess the dataset."""
formatted_rows = []
for row in rows:
input_text = None
output_text = None
# Process based on format
assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized"
if config.data_config.data_format.value == "instruct":
input_text, output_text = self._process_instruct_format(row)
elif config.data_config.data_format.value == "dialog":
input_text, output_text = self._process_dialog_format(row)
else:
input_text, output_text = self._process_fallback_format(row)
if input_text and output_text:
formatted_text = self._format_text(input_text, output_text, provider_config)
formatted_rows.append({"text": formatted_text})
if not formatted_rows:
assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized"
raise ValueError(
f"No valid input/output pairs found in the dataset for format: {config.data_config.data_format.value}"
)
return Dataset.from_list(formatted_rows)
def _preprocess_dataset(
self, ds: Dataset, tokenizer: AutoTokenizer, provider_config: HuggingFacePostTrainingConfig
) -> Dataset:
"""Preprocess the dataset with tokenizer."""
def tokenize_function(examples):
return tokenizer(
examples["text"],
padding=True,
truncation=True,
max_length=provider_config.max_seq_length,
return_tensors=None,
)
return ds.map(
tokenize_function,
batched=True,
remove_columns=ds.column_names,
)
async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
"""Load dataset from llama stack dataset provider"""
try:
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
limit=-1,
)
if not isinstance(all_rows.data, list):
raise RuntimeError("Expected dataset data to be a list")
return all_rows.data
except Exception as e:
raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
def _run_training_sync(
self,
model: str,
provider_config: dict[str, Any],
peft_config: LoraConfig | None,
config: dict[str, Any],
output_dir_path: Path | None,
) -> None:
"""Synchronous wrapper for running training process.
This method serves as a bridge between the multiprocessing Process and the async training function.
It creates a new event loop to run the async training process.
Args:
model: The model identifier to load
dataset_id: ID of the dataset to use for training
provider_config: Configuration specific to the HuggingFace provider
peft_config: Optional LoRA configuration
config: General training configuration
output_dir_path: Optional path to save the model
"""
import asyncio
logger.info("Starting training process with async wrapper")
asyncio.run(
self._run_training(
model=model,
provider_config=provider_config,
peft_config=peft_config,
config=config,
output_dir_path=output_dir_path,
)
)
async def load_dataset(
self,
model: str,
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
) -> tuple[Dataset, Dataset, AutoTokenizer]:
"""Load and prepare the dataset for training.
Args:
model: The model identifier to load
config: Training configuration
provider_config: Provider-specific configuration
Returns:
tuple: (train_dataset, eval_dataset, tokenizer)
"""
# Validate data config
if not config.data_config:
raise ValueError("DataConfig is required for training")
# Load dataset
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
rows = await self._setup_data(config.data_config.dataset_id)
if not self.validate_dataset_format(rows):
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
logger.info(f"Loaded {len(rows)} rows from dataset")
# Initialize tokenizer
logger.info(f"Initializing tokenizer for model: {model}")
try:
tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config)
# Set pad token to eos token if not present
# This is common for models that don't have a dedicated pad token
if not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.eos_token
# Set padding side to right for causal language modeling
# This ensures that padding tokens don't interfere with the model's ability
# to predict the next token in the sequence
tokenizer.padding_side = "right"
# Set truncation side to right to keep the beginning of the sequence
# This is important for maintaining context and instruction format
tokenizer.truncation_side = "right"
# Set model max length to match provider config
# This ensures consistent sequence lengths across the training process
tokenizer.model_max_length = provider_config.max_seq_length
logger.info("Tokenizer initialized successfully")
except Exception as e:
raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e
# Create and preprocess dataset
logger.info("Creating and preprocessing dataset")
try:
ds = self._create_dataset(rows, config, provider_config)
ds = self._preprocess_dataset(ds, tokenizer, provider_config)
logger.info(f"Dataset created with {len(ds)} examples")
except Exception as e:
raise ValueError(f"Failed to create dataset: {str(e)}") from e
# Split dataset
logger.info("Splitting dataset into train and validation sets")
train_val_split = ds.train_test_split(test_size=0.1, seed=42)
train_dataset = train_val_split["train"]
eval_dataset = train_val_split["test"]
logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
return train_dataset, eval_dataset, tokenizer
def load_model(
self,
model: str,
device: torch.device,
provider_config: HuggingFacePostTrainingConfig,
) -> AutoModelForCausalLM:
"""Load and initialize the model for training.
Args:
model: The model identifier to load
device: The device to load the model onto
provider_config: Provider-specific configuration
Returns:
The loaded and initialized model
Raises:
RuntimeError: If model loading fails
"""
logger.info("Loading the base model")
try:
model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
model_obj = AutoModelForCausalLM.from_pretrained(
model,
torch_dtype="auto" if device.type != "cpu" else "float32",
quantization_config=None,
config=model_config,
**provider_config.model_specific_config,
)
# Always move model to specified device
model_obj = model_obj.to(device)
logger.info(f"Model loaded and moved to device: {model_obj.device}")
return model_obj
except Exception as e:
raise RuntimeError(f"Failed to load model: {str(e)}") from e
def setup_training_args(
self,
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
device: torch.device,
output_dir_path: Path | None,
steps_per_epoch: int,
) -> SFTConfig:
"""Setup training arguments.
Args:
config: Training configuration
provider_config: Provider-specific configuration
device: The device to train on
output_dir_path: Optional path to save the model
steps_per_epoch: Number of steps per epoch
Returns:
Configured SFTConfig object
"""
logger.info("Configuring training arguments")
lr = 2e-5
if config.optimizer_config:
lr = config.optimizer_config.lr
logger.info(f"Using custom learning rate: {lr}")
# Validate data config
if not config.data_config:
raise ValueError("DataConfig is required for training")
data_config = config.data_config
# Calculate steps
total_steps = steps_per_epoch * config.n_epochs
max_steps = min(config.max_steps_per_epoch, total_steps)
eval_steps = max(1, steps_per_epoch // 10) # Evaluate 10 times per epoch
save_steps = max(1, steps_per_epoch // 5) # Save 5 times per epoch
logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch
logger.info("Training configuration:")
logger.info(f"- Steps per epoch: {steps_per_epoch}")
logger.info(f"- Total steps: {total_steps}")
logger.info(f"- Max steps: {max_steps}")
logger.info(f"- Eval steps: {eval_steps}")
logger.info(f"- Save steps: {save_steps}")
logger.info(f"- Logging steps: {logging_steps}")
# Configure save strategy
save_strategy = "no"
if output_dir_path:
save_strategy = "steps"
logger.info(f"Will save checkpoints to {output_dir_path}")
return SFTConfig(
max_steps=max_steps,
output_dir=str(output_dir_path) if output_dir_path is not None else None,
num_train_epochs=config.n_epochs,
per_device_train_batch_size=data_config.batch_size,
fp16=device.type == "cuda",
bf16=False, # Causes CPU issues.
eval_strategy="steps",
use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False,
save_strategy=save_strategy,
report_to="none",
max_seq_length=provider_config.max_seq_length,
gradient_accumulation_steps=config.gradient_accumulation_steps,
gradient_checkpointing=provider_config.gradient_checkpointing,
learning_rate=lr,
warmup_ratio=provider_config.warmup_ratio,
weight_decay=provider_config.weight_decay,
remove_unused_columns=False,
dataloader_pin_memory=provider_config.dataloader_pin_memory,
dataloader_num_workers=provider_config.dataloader_num_workers,
dataset_text_field="text",
packing=False,
load_best_model_at_end=True if output_dir_path else False,
metric_for_best_model="eval_loss",
greater_is_better=False,
eval_steps=eval_steps,
save_steps=save_steps,
logging_steps=logging_steps,
)
def save_model(
self,
model_obj: AutoModelForCausalLM,
trainer: SFTTrainer,
peft_config: LoraConfig | None,
output_dir_path: Path,
) -> None:
"""Save the trained model.
Args:
model_obj: The model to save
trainer: The trainer instance
peft_config: Optional LoRA configuration
output_dir_path: Path to save the model
"""
logger.info("Saving final model")
model_obj.config.use_cache = True
if peft_config:
logger.info("Merging LoRA weights with base model")
model_obj = trainer.model.merge_and_unload()
else:
model_obj = trainer.model
save_path = output_dir_path / "merged_model"
logger.info(f"Saving model to {save_path}")
model_obj.save_pretrained(save_path)
async def _run_training(
self,
model: str,
provider_config: dict[str, Any],
peft_config: LoraConfig | None,
config: dict[str, Any],
output_dir_path: Path | None,
) -> None:
"""Run the training process with signal handling."""
def signal_handler(signum, frame):
"""Handle termination signals gracefully."""
logger.info(f"Received signal {signum}, initiating graceful shutdown")
sys.exit(0)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
# Convert config dicts back to objects
logger.info("Initializing configuration objects")
provider_config_obj = HuggingFacePostTrainingConfig(**provider_config)
config_obj = TrainingConfig(**config)
# Initialize and validate device
device = setup_torch_device(provider_config_obj.device)
logger.info(f"Using device '{device}'")
# Load dataset and tokenizer
train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)
# Calculate steps per epoch
if not config_obj.data_config:
raise ValueError("DataConfig is required for training")
steps_per_epoch = len(train_dataset) // config_obj.data_config.batch_size
# Setup training arguments
training_args = self.setup_training_args(
config_obj,
provider_config_obj,
device,
output_dir_path,
steps_per_epoch,
)
# Load model
model_obj = self.load_model(model, device, provider_config_obj)
# Initialize trainer
logger.info("Initializing SFTTrainer")
trainer = SFTTrainer(
model=model_obj,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=peft_config,
args=training_args,
)
try:
# Train
logger.info("Starting training")
trainer.train()
logger.info("Training completed successfully")
# Save final model if output directory is provided
if output_dir_path:
self.save_model(model_obj, trainer, peft_config, output_dir_path)
finally:
# Clean up resources
logger.info("Cleaning up resources")
if hasattr(trainer, "model"):
evacuate_model_from_device(trainer.model, device.type)
del trainer
gc.collect()
logger.info("Cleanup completed")
async def train(
self,
model: str,
output_dir: str | None,
job_uuid: str,
lora_config: LoraFinetuningConfig,
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
) -> tuple[dict[str, Any], list[Checkpoint] | None]:
"""Train a model using HuggingFace's SFTTrainer"""
# Initialize and validate device
device = setup_torch_device(provider_config.device)
logger.info(f"Using device '{device}'")
output_dir_path = None
if output_dir:
output_dir_path = Path(output_dir)
# Track memory stats
memory_stats = {
"initial": get_memory_stats(device),
"after_training": None,
"final": None,
}
# Configure LoRA
peft_config = None
if lora_config:
peft_config = LoraConfig(
lora_alpha=lora_config.alpha,
lora_dropout=0.1,
r=lora_config.rank,
bias="none",
task_type="CAUSAL_LM",
target_modules=lora_config.lora_attn_modules,
)
# Validate data config
if not config.data_config:
raise ValueError("DataConfig is required for training")
# Train in a separate process
logger.info("Starting training in separate process")
try:
# Set multiprocessing start method to 'spawn' for CUDA/MPS compatibility
if device.type in ["cuda", "mps"]:
multiprocessing.set_start_method("spawn", force=True)
process = multiprocessing.Process(
target=self._run_training_sync,
kwargs={
"model": model,
"provider_config": provider_config.model_dump(),
"peft_config": peft_config,
"config": config.model_dump(),
"output_dir_path": output_dir_path,
},
)
process.start()
# Monitor the process
while process.is_alive():
process.join(timeout=1) # Check every second
if not process.is_alive():
break
# Get the return code
if process.exitcode != 0:
raise RuntimeError(f"Training failed with exit code {process.exitcode}")
memory_stats["after_training"] = get_memory_stats(device)
checkpoints = None
if output_dir_path:
# Create checkpoint
checkpoint = Checkpoint(
identifier=f"{model}-sft-{config.n_epochs}",
created_at=datetime.now(timezone.utc),
epoch=config.n_epochs,
post_training_job_id=job_uuid,
path=str(output_dir_path / "merged_model"),
)
checkpoints = [checkpoint]
return memory_stats, checkpoints
finally:
memory_stats["final"] = get_memory_stats(device)
gc.collect()

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import gc
import logging
import os
import time
@ -39,7 +38,6 @@ from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
Checkpoint,
DataConfig,
EfficiencyConfig,
LoraFinetuningConfig,
OptimizerConfig,
QATFinetuningConfig,
@ -48,6 +46,7 @@ from llama_stack.apis.post_training import (
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from llama_stack.providers.inline.post_training.torchtune.common import utils
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
TorchtuneCheckpointer,
@ -90,8 +89,6 @@ class LoraFinetuningSingleDevice:
) -> None:
assert isinstance(training_config.data_config, DataConfig), "DataConfig must be initialized"
assert isinstance(training_config.efficiency_config, EfficiencyConfig), "EfficiencyConfig must be initialized"
self.job_uuid = job_uuid
self.training_config = training_config
if not isinstance(algorithm_config, LoraFinetuningConfig):
@ -557,11 +554,7 @@ class LoraFinetuningSingleDevice:
checkpoints.append(checkpoint)
# clean up the memory after training finishes
if self._device.type != "cpu":
self._model.to("cpu")
torch.cuda.empty_cache()
del self._model
gc.collect()
evacuate_model_from_device(self._model, self._device.type)
return (memory_stats, checkpoints)

View file

@ -20,7 +20,10 @@ from opentelemetry.semconv.resource import ResourceAttributes
from llama_stack.apis.telemetry import (
Event,
MetricEvent,
MetricLabelMatcher,
MetricQueryType,
QueryCondition,
QueryMetricsResponse,
QuerySpanTreeResponse,
QueryTracesResponse,
Span,
@ -123,6 +126,17 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
else:
raise ValueError(f"Unknown event type: {event}")
async def query_metrics(
self,
metric_name: str,
start_time: int,
end_time: int | None = None,
granularity: str | None = "1d",
query_type: MetricQueryType = MetricQueryType.RANGE,
label_matchers: list[MetricLabelMatcher] | None = None,
) -> QueryMetricsResponse:
raise NotImplementedError("Querying metrics is not implemented")
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
# Use global storage instead of instance storage

View file

@ -87,6 +87,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
content,
chunk_size_in_tokens,
chunk_size_in_tokens // 4,
doc.metadata,
)
)
@ -105,7 +106,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
if not vector_db_ids:
return RAGQueryResult(content=None)
raise ValueError(
"No vector DBs were provided to the knowledge search tool. Please provide at least one vector DB ID."
)
query_config = query_config or RAGQueryConfig()
query = await generate_rag_query(
@ -140,19 +143,21 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
)
]
for i, c in enumerate(chunks):
metadata = c.metadata
for i, chunk in enumerate(chunks):
metadata = chunk.metadata
tokens += metadata["token_count"]
tokens += metadata["metadata_token_count"]
if tokens > query_config.max_tokens_in_context:
log.error(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
)
break
picked.append(
TextContentItem(
text=f"Result {i + 1}:\nDocument_id:{metadata['document_id'][:5]}\nContent: {c.content}\n",
)
)
metadata_subset = {k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]}
text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_subset)
picked.append(TextContentItem(text=text_content))
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
picked.append(
TextContentItem(

View file

@ -125,7 +125,7 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
# Load existing banks from kvstore
start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.range(start_key, end_key)
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
for vector_db_data in stored_vector_dbs:
vector_db = VectorDB.model_validate_json(vector_db_data)

View file

@ -280,11 +280,10 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter=AdapterSpec(
adapter_type="sambanova",
pip_packages=[
"openai",
],
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.sambanova",
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
),
),
remote_provider_spec(

View file

@ -21,6 +21,17 @@ def available_providers() -> list[ProviderSpec]:
Api.datasets,
],
),
InlineProviderSpec(
api=Api.post_training,
provider_type="inline::huggingface",
pip_packages=["torch", "trl", "transformers", "peft", "datasets"],
module="llama_stack.providers.inline.post_training.huggingface",
config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
),
remote_provider_spec(
api=Api.post_training,
adapter=AdapterSpec(

View file

@ -12,8 +12,8 @@ from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.pagination import paginate_records
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from .config import HuggingfaceDatasetIOConfig
@ -42,7 +42,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
# Load existing datasets from kvstore
start_key = DATASETS_PREFIX
end_key = f"{DATASETS_PREFIX}\xff"
stored_datasets = await self.kvstore.range(start_key, end_key)
stored_datasets = await self.kvstore.values_in_range(start_key, end_key)
for dataset in stored_datasets:
dataset = Dataset.model_validate_json(dataset)

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import CerebrasCompatConfig
async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .cerebras import CerebrasCompatInferenceAdapter

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import FireworksCompatConfig
async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .fireworks import FireworksCompatInferenceAdapter

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import GroqCompatConfig
async def get_adapter_impl(config: GroqCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: GroqCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .groq import GroqCompatInferenceAdapter

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import LlamaCompatConfig
async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .llama import LlamaCompatInferenceAdapter

View file

@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
EmbeddingsResponse,
EmbeddingTaskType,
GrammarResponseFormat,
Inference,
InferenceProvider,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
@ -61,6 +61,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
get_sampling_options,
prepare_openai_completion_params,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
@ -81,7 +82,7 @@ logger = get_logger(name=__name__, category="inference")
class OllamaInferenceAdapter(
Inference,
InferenceProvider,
ModelsProtocolPrivate,
):
def __init__(self, url: str) -> None:
@ -139,6 +140,8 @@ class OllamaInferenceAdapter(
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
@ -202,6 +205,8 @@ class OllamaInferenceAdapter(
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
@ -346,6 +351,8 @@ class OllamaInferenceAdapter(
# - models not currently running are run by the ollama server as needed
response = await self.client.list()
available_models = [m["model"] for m in response["models"]]
if model.provider_resource_id is None:
raise ValueError("Model provider_resource_id cannot be None")
provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
if provider_resource_id is None:
provider_resource_id = model.provider_resource_id
@ -389,29 +396,25 @@ class OllamaInferenceAdapter(
raise ValueError("Ollama does not support non-string prompts for completion")
model_obj = await self._get_model(model)
params = {
k: v
for k, v in {
"model": model_obj.provider_resource_id,
"prompt": prompt,
"best_of": best_of,
"echo": echo,
"frequency_penalty": frequency_penalty,
"logit_bias": logit_bias,
"logprobs": logprobs,
"max_tokens": max_tokens,
"n": n,
"presence_penalty": presence_penalty,
"seed": seed,
"stop": stop,
"stream": stream,
"stream_options": stream_options,
"temperature": temperature,
"top_p": top_p,
"user": user,
}.items()
if v is not None
}
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
return await self.openai_client.completions.create(**params) # type: ignore
async def openai_chat_completion(
@ -441,41 +444,31 @@ class OllamaInferenceAdapter(
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self._get_model(model)
# ollama still makes tool calls even when tool_choice is "none"
# so we need to remove the tools in that case
if tool_choice == "none" and tools is not None:
tools = None
params = {
k: v
for k, v in {
"model": model_obj.provider_resource_id,
"messages": messages,
"frequency_penalty": frequency_penalty,
"function_call": function_call,
"functions": functions,
"logit_bias": logit_bias,
"logprobs": logprobs,
"max_completion_tokens": max_completion_tokens,
"max_tokens": max_tokens,
"n": n,
"parallel_tool_calls": parallel_tool_calls,
"presence_penalty": presence_penalty,
"response_format": response_format,
"seed": seed,
"stop": stop,
"stream": stream,
"stream_options": stream_options,
"temperature": temperature,
"tool_choice": tool_choice,
"tools": tools,
"top_logprobs": top_logprobs,
"top_p": top_p,
"user": user,
}.items()
if v is not None
}
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await self.openai_client.chat.completions.create(**params) # type: ignore
async def batch_completion(

View file

@ -4,27 +4,60 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from dataclasses import dataclass
from llama_stack.apis.models.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
LLM_MODEL_IDS = [
# the models w/ "openai/" prefix are the litellm specific model names.
# they should be deprecated in favor of the canonical openai model names.
"openai/gpt-4o",
"openai/gpt-4o-mini",
"openai/chatgpt-4o-latest",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo",
"gpt-3.5-turbo-instruct",
"gpt-4",
"gpt-4-turbo",
"gpt-4o",
"gpt-4o-2024-08-06",
"gpt-4o-mini",
"gpt-4o-audio-preview",
"chatgpt-4o-latest",
"o1",
"o1-mini",
"o3-mini",
"o4-mini",
]
@dataclass
class EmbeddingModelInfo:
"""Structured representation of embedding model information."""
embedding_dimension: int
context_length: int
EMBEDDING_MODEL_IDS: dict[str, EmbeddingModelInfo] = {
"openai/text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
"openai/text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
"text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
"text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
}
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
ProviderModelEntry(
provider_model_id="openai/text-embedding-3-small",
provider_model_id=model_id,
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1536, "context_length": 8192},
),
ProviderModelEntry(
provider_model_id="openai/text-embedding-3-large",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 3072, "context_length": 8192},
),
metadata={
"embedding_dimension": model_info.embedding_dimension,
"context_length": model_info.context_length,
},
)
for model_id, model_info in EMBEDDING_MODEL_IDS.items()
]

View file

@ -4,12 +4,41 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from collections.abc import AsyncIterator
from typing import Any
from openai import AsyncOpenAI
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from .config import OpenAIConfig
from .models import MODEL_ENTRIES
logger = logging.getLogger(__name__)
#
# This OpenAI adapter implements Inference methods using two clients -
#
# | Inference Method | Implementation Source |
# |----------------------------|--------------------------|
# | completion | LiteLLMOpenAIMixin |
# | chat_completion | LiteLLMOpenAIMixin |
# | embedding | LiteLLMOpenAIMixin |
# | batch_completion | LiteLLMOpenAIMixin |
# | batch_chat_completion | LiteLLMOpenAIMixin |
# | openai_completion | AsyncOpenAI |
# | openai_chat_completion | AsyncOpenAI |
#
class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: OpenAIConfig) -> None:
LiteLLMOpenAIMixin.__init__(
@ -19,9 +48,120 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
provider_data_api_key_field="openai_api_key",
)
self.config = config
# we set is_openai_compat so users can use the canonical
# openai model names like "gpt-4" or "gpt-3.5-turbo"
# and the model name will be translated to litellm's
# "openai/gpt-4" or "openai/gpt-3.5-turbo" transparently.
# if we do not set this, users will be exposed to the
# litellm specific model names, an abstraction leak.
self.is_openai_compat = True
self._openai_client = AsyncOpenAI(
api_key=self.config.api_key,
)
async def initialize(self) -> None:
await super().initialize()
async def shutdown(self) -> None:
await super().shutdown()
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
) -> OpenAICompletion:
if guided_choice is not None:
logging.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
if prompt_logprobs is not None:
logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
params = await prepare_openai_completion_params(
model=(await self.model_store.get_model(model)).provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
return await self._openai_client.completions.create(**params)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
params = await prepare_openai_completion_params(
model=(await self.model_store.get_model(model)).provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await self._openai_client.chat.completions.create(**params)

View file

@ -4,16 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from llama_stack.apis.inference import Inference
from .config import SambaNovaImplConfig
class SambaNovaProviderDataValidator(BaseModel):
sambanova_api_key: str
async def get_adapter_impl(config: SambaNovaImplConfig, _deps):
async def get_adapter_impl(config: SambaNovaImplConfig, _deps) -> Inference:
from .sambanova import SambaNovaInferenceAdapter
assert isinstance(config, SambaNovaImplConfig), f"Unexpected config type: {type(config)}"

View file

@ -6,25 +6,32 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, SecretStr
from llama_stack.schema_utils import json_schema_type
class SambaNovaProviderDataValidator(BaseModel):
sambanova_api_key: str | None = Field(
default=None,
description="Sambanova Cloud API key",
)
@json_schema_type
class SambaNovaImplConfig(BaseModel):
url: str = Field(
default="https://api.sambanova.ai/v1",
description="The URL for the SambaNova AI server",
)
api_key: str | None = Field(
api_key: SecretStr | None = Field(
default=None,
description="The SambaNova.ai API Key",
description="The SambaNova cloud API Key",
)
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.sambanova.ai/v1",
"api_key": "${env.SAMBANOVA_API_KEY}",
"api_key": api_key,
}

View file

@ -11,43 +11,43 @@ from llama_stack.providers.utils.inference.model_registry import (
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"Meta-Llama-3.1-8B-Instruct",
"sambanova/Meta-Llama-3.1-8B-Instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"Meta-Llama-3.1-70B-Instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"Meta-Llama-3.1-405B-Instruct",
"sambanova/Meta-Llama-3.1-405B-Instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"Meta-Llama-3.2-1B-Instruct",
"sambanova/Meta-Llama-3.2-1B-Instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"Meta-Llama-3.2-3B-Instruct",
"sambanova/Meta-Llama-3.2-3B-Instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"Meta-Llama-3.3-70B-Instruct",
"sambanova/Meta-Llama-3.3-70B-Instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"Llama-3.2-11B-Vision-Instruct",
"sambanova/Llama-3.2-11B-Vision-Instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"Llama-3.2-90B-Vision-Instruct",
"sambanova/Llama-3.2-90B-Vision-Instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"Llama-4-Scout-17B-16E-Instruct",
"sambanova/Llama-4-Scout-17B-16E-Instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Llama-4-Maverick-17B-128E-Instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
]

View file

@ -5,305 +5,249 @@
# the root directory of this source tree.
import json
from collections.abc import AsyncGenerator
from collections.abc import Iterable
from openai import OpenAI
from openai.types.chat import (
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
)
from openai.types.chat import (
ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
)
from openai.types.chat import (
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
from openai.types.chat import (
ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
)
from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessage,
)
from openai.types.chat import (
ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
)
from openai.types.chat import (
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
)
from openai.types.chat import (
ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
)
from openai.types.chat import (
ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
)
from openai.types.chat.chat_completion_content_part_image_param import (
ImageURL as OpenAIImageURL,
)
from openai.types.chat.chat_completion_message_tool_call_param import (
Function as OpenAIFunction,
)
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionMessage,
EmbeddingsResponse,
EmbeddingTaskType,
GreedySamplingStrategy,
Inference,
LogProbConfig,
JsonSchemaResponseFormat,
Message,
ResponseFormat,
SamplingParams,
StopReason,
SystemMessage,
TextTruncation,
ToolCall,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
ToolResponseMessage,
TopKSamplingStrategy,
TopPSamplingStrategy,
UserMessage,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url,
convert_tooldef_to_openai_tool,
get_sampling_options,
)
from llama_stack.providers.utils.inference.prompt_adapter import convert_image_content_to_url
from .config import SambaNovaImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
class SambaNovaInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: SambaNovaImplConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
self.config = config
async def initialize(self) -> None:
return
async def convert_message_to_openai_dict_with_b64_images(
message: Message | dict,
) -> OpenAIChatCompletionMessage:
"""
Convert a Message to an OpenAI API-compatible dictionary.
"""
# users can supply a dict instead of a Message object, we'll
# convert it to a Message object and proceed with some type safety.
if isinstance(message, dict):
if "role" not in message:
raise ValueError("role is required in message")
if message["role"] == "user":
message = UserMessage(**message)
elif message["role"] == "assistant":
message = CompletionMessage(**message)
elif message["role"] == "tool":
message = ToolResponseMessage(**message)
elif message["role"] == "system":
message = SystemMessage(**message)
else:
raise ValueError(f"Unsupported message role: {message['role']}")
async def shutdown(self) -> None:
pass
def _get_client(self) -> OpenAI:
return OpenAI(base_url=self.config.url, api_key=self.config.api_key)
async def completion(
self,
model_id: str,
# Map Llama Stack spec to OpenAI spec -
# str -> str
# {"type": "text", "text": ...} -> {"type": "text", "text": ...}
# {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}}
# {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}}
# List[...] -> List[...]
async def _convert_message_content(
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = ToolPromptFormat.json,
stream: bool | None = False,
tool_config: ToolConfig | None = None,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
request_sambanova = await self.convert_chat_completion_request(request)
if stream:
return self._stream_chat_completion(request_sambanova)
else:
return await self._nonstream_chat_completion(request_sambanova)
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
response = self._get_client().chat.completions.create(**request)
choice = response.choices[0]
result = ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content or "",
stop_reason=self.convert_to_sambanova_finish_reason(choice.finish_reason),
tool_calls=self.convert_to_sambanova_tool_calls(choice.message.tool_calls),
),
logprobs=None,
)
return result
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
async def _to_async_generator():
streaming = self._get_client().chat.completions.create(**request)
for chunk in streaming:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
async def convert_chat_completion_request(self, request: ChatCompletionRequest) -> dict:
compatible_request = self.convert_sampling_params(request.sampling_params)
compatible_request["model"] = request.model
compatible_request["messages"] = await self.convert_to_sambanova_messages(request.messages)
compatible_request["stream"] = request.stream
compatible_request["logprobs"] = False
compatible_request["extra_headers"] = {
b"User-Agent": b"llama-stack: sambanova-inference-adapter",
}
compatible_request["tools"] = self.convert_to_sambanova_tool(request.tools)
return compatible_request
def convert_sampling_params(self, sampling_params: SamplingParams, legacy: bool = False) -> dict:
params = {}
if sampling_params:
params["frequency_penalty"] = sampling_params.repetition_penalty
if sampling_params.max_tokens:
if legacy:
params["max_tokens"] = sampling_params.max_tokens
else:
params["max_completion_tokens"] = sampling_params.max_tokens
if isinstance(sampling_params.strategy, TopPSamplingStrategy):
params["top_p"] = sampling_params.strategy.top_p
if isinstance(sampling_params.strategy, TopKSamplingStrategy):
params["extra_body"]["top_k"] = sampling_params.strategy.top_k
if isinstance(sampling_params.strategy, GreedySamplingStrategy):
params["temperature"] = 0.0
return params
async def convert_to_sambanova_messages(self, messages: list[Message]) -> list[dict]:
conversation = []
for message in messages:
content = {}
content["content"] = await self.convert_to_sambanova_content(message)
if isinstance(message, UserMessage):
content["role"] = "user"
elif isinstance(message, CompletionMessage):
content["role"] = "assistant"
tools = []
for tool_call in message.tool_calls:
tools.append(
{
"id": tool_call.call_id,
"function": {
"name": tool_call.name,
"arguments": json.dumps(tool_call.arguments),
},
"type": "function",
}
)
content["tool_calls"] = tools
elif isinstance(message, ToolResponseMessage):
content["role"] = "tool"
content["tool_call_id"] = message.call_id
elif isinstance(message, SystemMessage):
content["role"] = "system"
conversation.append(content)
return conversation
async def convert_to_sambanova_content(self, message: Message) -> dict:
async def _convert_content(content) -> dict:
if isinstance(content, ImageContentItem):
url = await convert_image_content_to_url(content, download=True)
# A fix to make sure the call sucess.
components = url.split(";base64")
url = f"{components[0].lower()};base64{components[1]}"
return {
"type": "image_url",
"image_url": {"url": url},
}
) -> str | Iterable[OpenAIChatCompletionContentPartParam]:
async def impl(
content_: InterleavedContent,
) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]:
# Llama Stack and OpenAI spec match for str and text input
if isinstance(content_, str):
return content_
elif isinstance(content_, TextContentItem):
return OpenAIChatCompletionContentPartTextParam(
type="text",
text=content_.text,
)
elif isinstance(content_, ImageContentItem):
return OpenAIChatCompletionContentPartImageParam(
type="image_url",
image_url=OpenAIImageURL(url=await convert_image_content_to_url(content_, download=True)),
)
elif isinstance(content_, list):
return [await impl(item) for item in content_]
else:
text = content.text if isinstance(content, TextContentItem) else content
assert isinstance(text, str)
return {"type": "text", "text": text}
raise ValueError(f"Unsupported content type: {type(content_)}")
if isinstance(message.content, list):
# If it is a list, the text content should be wrapped in dict
content = [await _convert_content(c) for c in message.content]
ret = await impl(content)
# OpenAI*Message expects a str or list
if isinstance(ret, str) or isinstance(ret, list):
return ret
else:
content = message.content
return [ret]
return content
out: OpenAIChatCompletionMessage = None
if isinstance(message, UserMessage):
out = OpenAIChatCompletionUserMessage(
role="user",
content=await _convert_message_content(message.content),
)
elif isinstance(message, CompletionMessage):
out = OpenAIChatCompletionAssistantMessage(
role="assistant",
content=await _convert_message_content(message.content),
tool_calls=[
OpenAIChatCompletionMessageToolCall(
id=tool.call_id,
function=OpenAIFunction(
name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value,
arguments=json.dumps(tool.arguments),
),
type="function",
)
for tool in message.tool_calls
]
or None,
)
elif isinstance(message, ToolResponseMessage):
out = OpenAIChatCompletionToolMessage(
role="tool",
tool_call_id=message.call_id,
content=await _convert_message_content(message.content),
)
elif isinstance(message, SystemMessage):
out = OpenAIChatCompletionSystemMessage(
role="system",
content=await _convert_message_content(message.content),
)
else:
raise ValueError(f"Unsupported message type: {type(message)}")
def convert_to_sambanova_tool(self, tools: list[ToolDefinition]) -> list[dict]:
if tools is None:
return tools
return out
compatiable_tools = []
for tool in tools:
properties = {}
compatiable_required = []
if tool.parameters:
for tool_key, tool_param in tool.parameters.items():
properties[tool_key] = {"type": tool_param.param_type}
if tool_param.description:
properties[tool_key]["description"] = tool_param.description
if tool_param.default:
properties[tool_key]["default"] = tool_param.default
if tool_param.required:
compatiable_required.append(tool_key)
class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
_config: SambaNovaImplConfig
compatiable_tool = {
"type": "function",
"function": {
"name": tool.tool_name,
"description": tool.description,
"parameters": {
"type": "object",
"properties": properties,
"required": compatiable_required,
},
def __init__(self, config: SambaNovaImplConfig):
self.config = config
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=self.config.api_key,
provider_data_api_key_field="sambanova_api_key",
)
def _get_api_key(self) -> str:
config_api_key = self.config.api_key if self.config.api_key else None
if config_api_key:
return config_api_key.get_secret_value()
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.sambanova_api_key:
raise ValueError(
'Pass Sambanova API Key in the header X-LlamaStack-Provider-Data as { "sambanova_api_key": <your api key> }'
)
return provider_data.sambanova_api_key
async def _get_params(self, request: ChatCompletionRequest) -> dict:
input_dict = {}
input_dict["messages"] = [await convert_message_to_openai_dict_with_b64_images(m) for m in request.messages]
if fmt := request.response_format:
if not isinstance(fmt, JsonSchemaResponseFormat):
raise ValueError(
f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
)
fmt = fmt.json_schema
name = fmt["title"]
del fmt["title"]
fmt["additionalProperties"] = False
# Apply additionalProperties: False recursively to all objects
fmt = self._add_additional_properties_recursive(fmt)
input_dict["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": name,
"schema": fmt,
"strict": True,
},
}
if request.tools:
input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
if request.tool_config.tool_choice:
input_dict["tool_choice"] = (
request.tool_config.tool_choice.value
if isinstance(request.tool_config.tool_choice, ToolChoice)
else request.tool_config.tool_choice
)
compatiable_tools.append(compatiable_tool)
provider_data = self.get_request_provider_data()
key_field = self.provider_data_api_key_field
if provider_data and getattr(provider_data, key_field, None):
api_key = getattr(provider_data, key_field)
else:
api_key = self._get_api_key()
if len(compatiable_tools) > 0:
return compatiable_tools
return None
def convert_to_sambanova_finish_reason(self, finish_reason: str) -> StopReason:
return {
"stop": StopReason.end_of_turn,
"length": StopReason.out_of_tokens,
"tool_calls": StopReason.end_of_message,
}.get(finish_reason, StopReason.end_of_turn)
"model": request.model,
"api_key": api_key,
"api_base": self.config.url,
**input_dict,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
def convert_to_sambanova_tool_calls(
self,
tool_calls,
) -> list[ToolCall]:
if not tool_calls:
return []
async def initialize(self):
await super().initialize()
compitable_tool_calls = [
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=json.loads(call.function.arguments),
arguments_json=call.function.arguments,
)
for call in tool_calls
]
return compitable_tool_calls
async def shutdown(self):
await super().shutdown()

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import SambaNovaCompatConfig
async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .sambanova import SambaNovaCompatInferenceAdapter

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import TogetherCompatConfig
async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .together import TogetherCompatInferenceAdapter

View file

@ -158,27 +158,28 @@ def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
}.get(finish_reason, StopReason.end_of_turn)
async def _process_vllm_chat_completion_stream_response(
stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
) -> AsyncGenerator:
event_type = ChatCompletionResponseEventType.start
tool_call_buf = UnparseableToolCall()
async for chunk in stream:
if not chunk.choices:
log.warning("vLLM failed to generation any completions - check the vLLM server logs for an error.")
continue
choice = chunk.choices[0]
if choice.finish_reason:
args_str = tool_call_buf.arguments
args = None
try:
args = {} if not args_str else json.loads(args_str)
except Exception as e:
log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
if args:
yield ChatCompletionResponseStreamChunk(
def _process_vllm_chat_completion_end_of_stream(
finish_reason: str | None,
last_chunk_content: str | None,
current_event_type: ChatCompletionResponseEventType,
tool_call_bufs: dict[str, UnparseableToolCall] | None = None,
) -> list[OpenAIChatCompletionChunk]:
chunks = []
if finish_reason is not None:
stop_reason = _convert_to_vllm_finish_reason(finish_reason)
else:
stop_reason = StopReason.end_of_message
tool_call_bufs = tool_call_bufs or {}
for _index, tool_call_buf in sorted(tool_call_bufs.items()):
args_str = tool_call_buf.arguments or "{}"
try:
args = json.loads(args_str)
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
event_type=current_event_type,
delta=ToolCallDelta(
tool_call=ToolCall(
call_id=tool_call_buf.call_id,
@ -190,8 +191,12 @@ async def _process_vllm_chat_completion_stream_response(
),
)
)
elif args_str:
yield ChatCompletionResponseStreamChunk(
)
except Exception as e:
log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
@ -200,21 +205,62 @@ async def _process_vllm_chat_completion_stream_response(
),
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=choice.delta.content or ""),
logprobs=None,
stop_reason=_convert_to_vllm_finish_reason(choice.finish_reason),
)
)
elif choice.delta.tool_calls:
tool_call = convert_tool_call(choice.delta.tool_calls[0])
tool_call_buf.tool_name += str(tool_call.tool_name)
tool_call_buf.call_id += tool_call.call_id
# TODO: remove str() when dict type for 'arguments' is no longer allowed
tool_call_buf.arguments += str(tool_call.arguments)
else:
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=last_chunk_content or ""),
logprobs=None,
stop_reason=stop_reason,
)
)
)
return chunks
async def _process_vllm_chat_completion_stream_response(
stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
) -> AsyncGenerator:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
)
)
event_type = ChatCompletionResponseEventType.progress
tool_call_bufs: dict[str, UnparseableToolCall] = {}
end_of_stream_processed = False
async for chunk in stream:
if not chunk.choices:
log.warning("vLLM failed to generation any completions - check the vLLM server logs for an error.")
return
choice = chunk.choices[0]
if choice.delta.tool_calls:
for delta_tool_call in choice.delta.tool_calls:
tool_call = convert_tool_call(delta_tool_call)
if delta_tool_call.index not in tool_call_bufs:
tool_call_bufs[delta_tool_call.index] = UnparseableToolCall()
tool_call_buf = tool_call_bufs[delta_tool_call.index]
tool_call_buf.tool_name += str(tool_call.tool_name)
tool_call_buf.call_id += tool_call.call_id
tool_call_buf.arguments += (
tool_call.arguments if isinstance(tool_call.arguments, str) else json.dumps(tool_call.arguments)
)
if choice.finish_reason:
chunks = _process_vllm_chat_completion_end_of_stream(
finish_reason=choice.finish_reason,
last_chunk_content=choice.delta.content,
current_event_type=event_type,
tool_call_bufs=tool_call_bufs,
)
for c in chunks:
yield c
end_of_stream_processed = True
elif not choice.delta.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
@ -224,6 +270,17 @@ async def _process_vllm_chat_completion_stream_response(
)
event_type = ChatCompletionResponseEventType.progress
if end_of_stream_processed:
return
# the stream ended without a chunk containing finish_reason - we have to generate the
# respective completion chunks manually
chunks = _process_vllm_chat_completion_end_of_stream(
finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_bufs=tool_call_bufs
)
for c in chunks:
yield c
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
@ -272,6 +329,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
@ -302,6 +361,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
# References:
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice

View file

@ -26,8 +26,7 @@ from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
log = logging.getLogger(__name__)
ChromaClientType = chromadb.AsyncHttpClient | chromadb.PersistentClient
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
# this is a helper to allow us to use async and non-async chroma clients interchangeably

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,55 +0,0 @@
inference:
tests:
- inference/test_vision_inference.py::test_vision_chat_completion_streaming
- inference/test_vision_inference.py::test_vision_chat_completion_non_streaming
- inference/test_text_inference.py::test_structured_output
- inference/test_text_inference.py::test_chat_completion_streaming
- inference/test_text_inference.py::test_chat_completion_non_streaming
- inference/test_text_inference.py::test_chat_completion_with_tool_calling
- inference/test_text_inference.py::test_chat_completion_with_tool_calling_streaming
scenarios:
- provider_fixtures:
inference: ollama
- fixture_combo_id: fireworks
- provider_fixtures:
inference: together
# - inference: tgi
# - inference: vllm_remote
inference_models:
- meta-llama/Llama-3.1-8B-Instruct
- meta-llama/Llama-3.2-11B-Vision-Instruct
agents:
tests:
- agents/test_agents.py::test_agent_turns_with_safety
- agents/test_agents.py::test_rag_agent
scenarios:
- fixture_combo_id: ollama
- fixture_combo_id: together
- fixture_combo_id: fireworks
inference_models:
- meta-llama/Llama-3.2-1B-Instruct
safety_shield: meta-llama/Llama-Guard-3-1B
memory:
tests:
- memory/test_memory.py::test_query_documents
scenarios:
- fixture_combo_id: ollama
- provider_fixtures:
inference: sentence_transformers
memory: faiss
- fixture_combo_id: chroma
inference_models:
- meta-llama/Llama-3.2-1B-Instruct
embedding_model: all-MiniLM-L6-v2

View file

@ -1,296 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from collections import defaultdict
from pathlib import Path
from typing import Any
import pytest
import yaml
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from termcolor import colored
from llama_stack.distribution.datatypes import Provider
from llama_stack.providers.datatypes import RemoteProviderConfig
from .env import get_env_or_fail
from .report import Report
class ProviderFixture(BaseModel):
providers: list[Provider]
provider_data: dict[str, Any] | None = None
class TestScenario(BaseModel):
# provider fixtures can be either a mark or a dictionary of api -> providers
provider_fixtures: dict[str, str] = Field(default_factory=dict)
fixture_combo_id: str | None = None
class APITestConfig(BaseModel):
scenarios: list[TestScenario] = Field(default_factory=list)
inference_models: list[str] = Field(default_factory=list)
# test name format should be <relative_path.py>::<test_name>
tests: list[str] = Field(default_factory=list)
class MemoryApiTestConfig(APITestConfig):
embedding_model: str | None = Field(default_factory=None)
class AgentsApiTestConfig(APITestConfig):
safety_shield: str | None = Field(default_factory=None)
class TestConfig(BaseModel):
inference: APITestConfig | None = None
agents: AgentsApiTestConfig | None = None
memory: MemoryApiTestConfig | None = None
def get_test_config_from_config_file(metafunc_config):
config_file = metafunc_config.getoption("--config")
if config_file is None:
return None
config_file_path = Path(__file__).parent / config_file
if not config_file_path.exists():
raise ValueError(
f"Test config {config_file} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory."
)
with open(config_file_path) as config_file:
config = yaml.safe_load(config_file)
return TestConfig(**config)
def get_test_config_for_api(metafunc_config, api):
test_config = get_test_config_from_config_file(metafunc_config)
if test_config is None:
return None
return getattr(test_config, api)
def get_provider_fixture_overrides_from_test_config(metafunc_config, api, default_provider_fixture_combinations):
api_config = get_test_config_for_api(metafunc_config, api)
if api_config is None:
return None
fixture_combo_ids = set()
custom_provider_fixture_combos = []
for scenario in api_config.scenarios:
if scenario.fixture_combo_id:
fixture_combo_ids.add(scenario.fixture_combo_id)
else:
custom_provider_fixture_combos.append(
pytest.param(
scenario.provider_fixtures,
id=scenario.provider_fixtures.get("inference") or "",
)
)
if len(fixture_combo_ids) > 0:
for default_fixture in default_provider_fixture_combinations:
if default_fixture.id in fixture_combo_ids:
custom_provider_fixture_combos.append(default_fixture)
return custom_provider_fixture_combos
def remote_stack_fixture() -> ProviderFixture:
if url := os.getenv("REMOTE_STACK_URL", None):
config = RemoteProviderConfig.from_url(url)
else:
config = RemoteProviderConfig(
host=get_env_or_fail("REMOTE_STACK_HOST"),
port=int(get_env_or_fail("REMOTE_STACK_PORT")),
)
return ProviderFixture(
providers=[
Provider(
provider_id="test::remote",
provider_type="test::remote",
config=config.model_dump(),
)
],
)
def pytest_configure(config):
config.option.tbstyle = "short"
config.option.disable_warnings = True
"""Load environment variables at start of test run"""
# Load from .env file if it exists
env_file = Path(__file__).parent / ".env"
if env_file.exists():
load_dotenv(env_file)
# Load any environment variables passed via --env
env_vars = config.getoption("--env") or []
for env_var in env_vars:
key, value = env_var.split("=", 1)
os.environ[key] = value
if config.getoption("--output") is not None:
config.pluginmanager.register(Report(config.getoption("--output")))
def pytest_addoption(parser):
parser.addoption(
"--providers",
default="",
help=(
"Provider configuration in format: api1=provider1,api2=provider2. "
"Example: --providers inference=ollama,safety=meta-reference"
),
)
parser.addoption(
"--config",
action="store",
help="Set test config file (supported format: YAML), e.g. --config=test_config.yml",
)
parser.addoption(
"--output",
action="store",
help="Set output file for test report, e.g. --output=pytest_report.md",
)
"""Add custom command line options"""
parser.addoption("--env", action="append", help="Set environment variables, e.g. --env KEY=value")
parser.addoption(
"--inference-model",
action="store",
default="meta-llama/Llama-3.2-3B-Instruct",
help="Specify the inference model to use for testing",
)
parser.addoption(
"--safety-shield",
action="store",
default="meta-llama/Llama-Guard-3-1B",
help="Specify the safety shield to use for testing",
)
parser.addoption(
"--embedding-model",
action="store",
default=None,
help="Specify the embedding model to use for testing",
)
parser.addoption(
"--judge-model",
action="store",
default="meta-llama/Llama-3.1-8B-Instruct",
help="Specify the judge model to use for testing",
)
def make_provider_id(providers: dict[str, str]) -> str:
return ":".join(f"{api}={provider}" for api, provider in sorted(providers.items()))
def get_provider_marks(providers: dict[str, str]) -> list[Any]:
marks = []
for provider in providers.values():
marks.append(getattr(pytest.mark, provider))
return marks
def get_provider_fixture_overrides(config, available_fixtures: dict[str, list[str]]) -> list[pytest.param] | None:
provider_str = config.getoption("--providers")
if not provider_str:
return None
fixture_dict = parse_fixture_string(provider_str, available_fixtures)
return [
pytest.param(
fixture_dict,
id=make_provider_id(fixture_dict),
marks=get_provider_marks(fixture_dict),
)
]
def parse_fixture_string(provider_str: str, available_fixtures: dict[str, list[str]]) -> dict[str, str]:
"""Parse provider string of format 'api1=provider1,api2=provider2'"""
if not provider_str:
return {}
fixtures = {}
pairs = provider_str.split(",")
for pair in pairs:
if "=" not in pair:
raise ValueError(f"Invalid provider specification: {pair}. Expected format: api=provider")
api, fixture = pair.split("=")
if api not in available_fixtures:
raise ValueError(f"Unknown API: {api}. Available APIs: {list(available_fixtures.keys())}")
if fixture not in available_fixtures[api]:
raise ValueError(
f"Unknown provider '{fixture}' for API '{api}'. Available providers: {list(available_fixtures[api])}"
)
fixtures[api] = fixture
# Check that all provided APIs are supported
for api in available_fixtures.keys():
if api not in fixtures:
raise ValueError(
f"Missing provider fixture for API '{api}'. Available providers: {list(available_fixtures[api])}"
)
return fixtures
def pytest_itemcollected(item):
# Get all markers as a list
filtered = ("asyncio", "parametrize")
marks = [mark.name for mark in item.iter_markers() if mark.name not in filtered]
if marks:
marks = colored(",".join(marks), "yellow")
item.name = f"{item.name}[{marks}]"
def pytest_collection_modifyitems(session, config, items):
test_config = get_test_config_from_config_file(config)
if test_config is None:
return
required_tests = defaultdict(set)
for api_test_config in [
test_config.inference,
test_config.memory,
test_config.agents,
]:
if api_test_config is None:
continue
for test in api_test_config.tests:
arr = test.split("::")
if len(arr) != 2:
raise ValueError(f"Invalid format for test name {test}")
test_path, func_name = arr
required_tests[Path(__file__).parent / test_path].add(func_name)
new_items, deselected_items = [], []
for item in items:
func_name = getattr(item, "originalname", item.name)
if func_name in required_tests[item.fspath]:
new_items.append(item)
continue
deselected_items.append(item)
items[:] = new_items
config.hook.pytest_deselected(items=deselected_items)
pytest_plugins = [
"llama_stack.providers.tests.inference.fixtures",
"llama_stack.providers.tests.safety.fixtures",
"llama_stack.providers.tests.vector_io.fixtures",
"llama_stack.providers.tests.agents.fixtures",
"llama_stack.providers.tests.datasetio.fixtures",
"llama_stack.providers.tests.scoring.fixtures",
"llama_stack.providers.tests.eval.fixtures",
"llama_stack.providers.tests.post_training.fixtures",
"llama_stack.providers.tests.tools.fixtures",
]

View file

@ -1,176 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections import defaultdict
from pathlib import Path
import pytest
from pytest import ExitCode
from pytest_html.basereport import _process_outcome
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.models.llama.sku_types import CoreModelId
INFERENCE_APIS = ["chat_completion"]
FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"]
SUPPORTED_MODELS = {
"ollama": {
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_1b.value,
},
"fireworks": {
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_11b_vision.value,
},
"together": {
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_11b_vision.value,
},
}
class Report:
def __init__(self, output_path):
valid_file_format = (
output_path.split(".")[1] in ["md", "markdown"] if len(output_path.split(".")) == 2 else False
)
if not valid_file_format:
raise ValueError(f"Invalid output file {output_path}. Markdown file is required")
self.output_path = output_path
self.test_data = defaultdict(dict)
self.inference_tests = defaultdict(dict)
@pytest.hookimpl
def pytest_runtest_logreport(self, report):
# This hook is called in several phases, including setup, call and teardown
# The test is considered failed / error if any of the outcomes is not "Passed"
outcome = _process_outcome(report)
data = {
"outcome": report.outcome,
"longrepr": report.longrepr,
"name": report.nodeid,
}
if report.nodeid not in self.test_data:
self.test_data[report.nodeid] = data
elif self.test_data[report.nodeid] != outcome and outcome != "Passed":
self.test_data[report.nodeid] = data
@pytest.hookimpl
def pytest_sessionfinish(self, session, exitstatus):
if exitstatus <= ExitCode.INTERRUPTED:
return
report = []
report.append("# Llama Stack Integration Test Results Report")
report.append("\n## Summary")
report.append("\n## Supported Models: ")
header = "| Model Descriptor |"
dividor = "|:---|"
for k in SUPPORTED_MODELS.keys():
header += f"{k} |"
dividor += ":---:|"
report.append(header)
report.append(dividor)
rows = []
for model in all_registered_models():
if "Instruct" not in model.core_model_id.value and "Guard" not in model.core_model_id.value:
continue
row = f"| {model.core_model_id.value} |"
for k in SUPPORTED_MODELS.keys():
if model.core_model_id.value in SUPPORTED_MODELS[k]:
row += " ✅ |"
else:
row += " ❌ |"
rows.append(row)
report.extend(rows)
report.append("\n### Tests:")
for provider in SUPPORTED_MODELS.keys():
if provider not in self.inference_tests:
continue
report.append(f"\n #### {provider}")
test_table = [
"| Area | Model | API | Functionality Test | Status |",
"|:-----|:-----|:-----|:-----|:-----|",
]
for api in INFERENCE_APIS:
tests = self.inference_tests[provider][api]
for test_nodeid in tests:
row = "|{area} | {model} | {api} | {test} | {result} ".format(
area="Text" if "text" in test_nodeid else "Vision",
model=("Llama-3.1-8B-Instruct" if "text" in test_nodeid else "Llama3.2-11B-Vision-Instruct"),
api=f"/{api}",
test=self.get_simple_function_name(test_nodeid),
result=("" if self.test_data[test_nodeid]["outcome"] == "passed" else ""),
)
test_table += [row]
report.extend(test_table)
report.append("\n")
output_file = Path(self.output_path)
output_file.write_text("\n".join(report))
print(f"\n Report generated: {output_file.absolute()}")
@pytest.hookimpl(trylast=True)
def pytest_collection_modifyitems(self, session, config, items):
for item in items:
inference = item.callspec.params.get("inference_stack")
if "inference" in item.nodeid:
func_name = getattr(item, "originalname", item.name)
for api in INFERENCE_APIS:
if api in func_name:
api_tests = self.inference_tests[inference].get(api, set())
api_tests.add(item.nodeid)
self.inference_tests[inference][api] = api_tests
def get_simple_function_name(self, nodeid):
"""Extract function name from nodeid.
Examples:
- 'tests/test_math.py::test_addition' -> 'test_addition'
- 'tests/test_math.py::TestClass::test_method' -> test_method'
"""
parts = nodeid.split("::")
func_name = nodeid # Fallback to full nodeid if pattern doesn't match
if len(parts) == 2: # Simple function
func_name = parts[1]
elif len(parts) == 3: # Class method
func_name = parts[2]
return func_name.split("[")[0]

View file

@ -19,7 +19,7 @@ from llama_stack.apis.inference import (
ChatCompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
InferenceProvider,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
@ -59,9 +59,12 @@ logger = get_logger(name=__name__, category="inference")
class LiteLLMOpenAIMixin(
ModelRegistryHelper,
Inference,
InferenceProvider,
NeedsRequestProviderData,
):
# TODO: avoid exposing the litellm specific model names to the user.
# potential change: add a prefix param that gets added to the model name
# when calling litellm.
def __init__(
self,
model_entries,
@ -92,7 +95,9 @@ class LiteLLMOpenAIMixin(
return model
def get_litellm_model_name(self, model_id: str) -> str:
return "openai/" + model_id if self.is_openai_compat else model_id
# users may be using openai/ prefix in their model names. the openai/models.py did this by default.
# model_id.startswith("openai/") is for backwards compatibility.
return "openai/" + model_id if self.is_openai_compat and not model_id.startswith("openai/") else model_id
async def completion(
self,

View file

@ -531,13 +531,19 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
tool_name = tc.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
# arguments_json can be None, so attempt it first and fall back to arguments
if hasattr(tc, "arguments_json") and tc.arguments_json:
arguments = tc.arguments_json
else:
arguments = json.dumps(tc.arguments)
result["tool_calls"].append(
{
"id": tc.call_id,
"type": "function",
"function": {
"name": tool_name,
"arguments": tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments),
"arguments": arguments,
},
}
)

View file

@ -382,7 +382,7 @@ def augment_messages_for_tools_llama_3_1(
messages.append(SystemMessage(content=sys_content))
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
has_custom_tools = request.tools is not None and any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
if fmt == ToolPromptFormat.json:

View file

@ -16,4 +16,6 @@ class KVStore(Protocol):
async def delete(self, key: str) -> None: ...
async def range(self, start_key: str, end_key: str) -> list[str]: ...
async def values_in_range(self, start_key: str, end_key: str) -> list[str]: ...
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: ...

View file

@ -26,9 +26,16 @@ class InmemoryKVStoreImpl(KVStore):
async def set(self, key: str, value: str) -> None:
self._store[key] = value
async def range(self, start_key: str, end_key: str) -> list[str]:
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key]
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
"""Get all keys in the given range."""
return [key for key in self._store.keys() if key >= start_key and key < end_key]
async def delete(self, key: str) -> None:
del self._store[key]
async def kvstore_impl(config: KVStoreConfig) -> KVStore:
if config.type == KVStoreType.redis.value:

View file

@ -57,7 +57,7 @@ class MongoDBKVStoreImpl(KVStore):
key = self._namespaced_key(key)
await self.collection.delete_one({"key": key})
async def range(self, start_key: str, end_key: str) -> list[str]:
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
query = {
@ -68,3 +68,10 @@ class MongoDBKVStoreImpl(KVStore):
async for doc in cursor:
result.append(doc["value"])
return result
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
query = {"key": {"$gte": start_key, "$lt": end_key}}
cursor = self.collection.find(query, {"key": 1, "_id": 0}).sort("key", 1)
return [doc["key"] for doc in cursor]

View file

@ -85,7 +85,7 @@ class PostgresKVStoreImpl(KVStore):
(key,),
)
async def range(self, start_key: str, end_key: str) -> list[str]:
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
@ -99,3 +99,13 @@ class PostgresKVStoreImpl(KVStore):
(start_key, end_key),
)
return [row[0] for row in self.cursor.fetchall()]
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
self.cursor.execute(
f"SELECT key FROM {self.config.table_name} WHERE key >= %s AND key < %s",
(start_key, end_key),
)
return [row[0] for row in self.cursor.fetchall()]

View file

@ -42,7 +42,7 @@ class RedisKVStoreImpl(KVStore):
key = self._namespaced_key(key)
await self.redis.delete(key)
async def range(self, start_key: str, end_key: str) -> list[str]:
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
cursor = 0
@ -67,3 +67,10 @@ class RedisKVStoreImpl(KVStore):
]
return []
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
"""Get all keys in the given range."""
matching_keys = await self.redis.zrangebylex(self.namespace, f"[{start_key}", f"[{end_key}")
if not matching_keys:
return []
return [k.decode("utf-8") for k in matching_keys]

View file

@ -54,7 +54,7 @@ class SqliteKVStoreImpl(KVStore):
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
await db.commit()
async def range(self, start_key: str, end_key: str) -> list[str]:
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
@ -65,3 +65,13 @@ class SqliteKVStoreImpl(KVStore):
_, value, _ = row
result.append(value)
return result
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
"""Get all keys in the given range."""
async with aiosqlite.connect(self.db_path) as db:
cursor = await db.execute(
f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
(start_key, end_key),
)
rows = await cursor.fetchall()
return [row[0] for row in rows]

View file

@ -118,45 +118,53 @@ async def content_from_doc(doc: RAGDocument) -> str:
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
return content_from_data(doc.content.uri)
else:
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
else:
return r.text
pattern = re.compile("^(https?://|file://|data:)")
if pattern.match(doc.content):
if doc.content.startswith("data:"):
return content_from_data(doc.content)
else:
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
return r.text
elif isinstance(doc.content, str):
pattern = re.compile("^(https?://|file://|data:)")
if pattern.match(doc.content):
if doc.content.startswith("data:"):
return content_from_data(doc.content)
async with httpx.AsyncClient() as client:
r = await client.get(doc.content)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
else:
return r.text
return interleaved_content_as_str(doc.content)
return r.text
return doc.content
else:
# will raise ValueError if the content is not List[InterleavedContent] or InterleavedContent
return interleaved_content_as_str(doc.content)
def make_overlapped_chunks(document_id: str, text: str, window_len: int, overlap_len: int) -> list[Chunk]:
def make_overlapped_chunks(
document_id: str, text: str, window_len: int, overlap_len: int, metadata: dict[str, Any]
) -> list[Chunk]:
tokenizer = Tokenizer.get_instance()
tokens = tokenizer.encode(text, bos=False, eos=False)
try:
metadata_string = str(metadata)
except Exception as e:
raise ValueError("Failed to serialize metadata to string") from e
metadata_tokens = tokenizer.encode(metadata_string, bos=False, eos=False)
chunks = []
for i in range(0, len(tokens), window_len - overlap_len):
toks = tokens[i : i + window_len]
chunk = tokenizer.decode(toks)
chunk_metadata = metadata.copy()
chunk_metadata["document_id"] = document_id
chunk_metadata["token_count"] = len(toks)
chunk_metadata["metadata_token_count"] = len(metadata_tokens)
# chunk is a string
chunks.append(
Chunk(
content=chunk,
metadata={
"token_count": len(toks),
"document_id": document_id,
},
metadata=chunk_metadata,
)
)