chore(package): migrate to src/ layout (#3920)

Migrates package structure to src/ layout following Python packaging
best practices.

All code moved from `llama_stack/` to `src/llama_stack/`. Public API
unchanged - imports remain `import llama_stack.*`.

Updated build configs, pre-commit hooks, scripts, and GitHub workflows
accordingly. All hooks pass, package builds cleanly.

**Developer note**: Reinstall after pulling: `pip install -e .`
This commit is contained in:
Ashwin Bharambe 2025-10-27 12:02:21 -07:00 committed by GitHub
parent 98a5047f9d
commit 471b1b248b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
791 changed files with 2983 additions and 456 deletions

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,217 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from typing import Any, Protocol
from urllib.parse import urlparse
from pydantic import BaseModel, Field
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.datasets import Dataset
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import Model
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import ToolGroup
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.schema_utils import json_schema_type
class ModelsProtocolPrivate(Protocol):
"""
Protocol for model management.
This allows users to register their preferred model identifiers.
Model registration requires -
- a provider, used to route the registration request
- a model identifier, user's intended name for the model during inference
- a provider model identifier, a model identifier supported by the provider
Providers will only accept registration for provider model ids they support.
Example,
register: provider x my-model-id x provider-model-id
-> Error if provider does not support provider-model-id
-> Error if my-model-id is already registered
-> Success if provider supports provider-model-id
inference: my-model-id x ...
-> Provider uses provider-model-id for inference
"""
# this should be called `on_model_register` or something like that.
# the provider should _not_ be able to change the object in this
# callback
async def register_model(self, model: Model) -> Model: ...
async def unregister_model(self, model_id: str) -> None: ...
# the Stack router will query each provider for their list of models
# if a `refresh_interval_seconds` is provided, this method will be called
# periodically to refresh the list of models
#
# NOTE: each model returned will be registered with the model registry. this means
# a callback to the `register_model()` method will be made. this is duplicative and
# may be removed in the future.
async def list_models(self) -> list[Model] | None: ...
async def should_refresh_models(self) -> bool: ...
class ShieldsProtocolPrivate(Protocol):
async def register_shield(self, shield: Shield) -> None: ...
async def unregister_shield(self, identifier: str) -> None: ...
class VectorStoresProtocolPrivate(Protocol):
async def register_vector_store(self, vector_store: VectorStore) -> None: ...
async def unregister_vector_store(self, vector_store_id: str) -> None: ...
class DatasetsProtocolPrivate(Protocol):
async def register_dataset(self, dataset: Dataset) -> None: ...
async def unregister_dataset(self, dataset_id: str) -> None: ...
class ScoringFunctionsProtocolPrivate(Protocol):
async def list_scoring_functions(self) -> list[ScoringFn]: ...
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: ...
class BenchmarksProtocolPrivate(Protocol):
async def register_benchmark(self, benchmark: Benchmark) -> None: ...
class ToolGroupsProtocolPrivate(Protocol):
async def register_toolgroup(self, toolgroup: ToolGroup) -> None: ...
async def unregister_toolgroup(self, toolgroup_id: str) -> None: ...
@json_schema_type
class ProviderSpec(BaseModel):
api: Api
provider_type: str
config_class: str = Field(
...,
description="Fully-qualified classname of the config for this provider",
)
api_dependencies: list[Api] = Field(
default_factory=list,
description="Higher-level API surfaces may depend on other providers to provide their functionality",
)
optional_api_dependencies: list[Api] = Field(
default_factory=list,
)
deprecation_warning: str | None = Field(
default=None,
description="If this provider is deprecated, specify the warning message here",
)
deprecation_error: str | None = Field(
default=None,
description="If this provider is deprecated and does NOT work, specify the error message here",
)
module: str | None = Field(
default=None,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
Example: `module: ramalama_stack`
""",
)
pip_packages: list[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
provider_data_validator: str | None = Field(
default=None,
)
is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")
# used internally by the resolver; this is a hack for now
deps__: list[str] = Field(default_factory=list)
@property
def is_sample(self) -> bool:
return self.provider_type in ("sample", "remote::sample")
class RoutingTable(Protocol):
async def get_provider_impl(self, routing_key: str) -> Any: ...
@json_schema_type
class InlineProviderSpec(ProviderSpec):
container_image: str | None = Field(
default=None,
description="""
The container image to use for this implementation. If one is provided, pip_packages will be ignored.
If a provider depends on other providers, the dependencies MUST NOT specify a container image.
""",
)
description: str | None = Field(
default=None,
description="""
A description of the provider. This is used to display in the documentation.
""",
)
class RemoteProviderConfig(BaseModel):
host: str = "localhost"
port: int | None = None
protocol: str = "http"
@property
def url(self) -> str:
if self.port is None:
return f"{self.protocol}://{self.host}"
return f"{self.protocol}://{self.host}:{self.port}"
@classmethod
def from_url(cls, url: str) -> "RemoteProviderConfig":
parsed = urlparse(url)
attrs = {k: v for k, v in parsed._asdict().items() if v is not None}
return cls(**attrs)
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
adapter_type: str = Field(
...,
description="Unique identifier for this adapter",
)
description: str | None = Field(
default=None,
description="""
A description of the provider. This is used to display in the documentation.
""",
)
@property
def container_image(self) -> str | None:
return None
class HealthStatus(StrEnum):
OK = "OK"
ERROR = "Error"
NOT_IMPLEMENTED = "Not Implemented"
HealthResponse = dict[str, Any]

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.core.datatypes import AccessRule, Api
from .config import MetaReferenceAgentsImplConfig
async def get_provider_impl(
config: MetaReferenceAgentsImplConfig,
deps: dict[Api, Any],
policy: list[AccessRule],
telemetry_enabled: bool = False,
):
from .agents import MetaReferenceAgentsImpl
impl = MetaReferenceAgentsImpl(
config,
deps[Api.inference],
deps[Api.vector_io],
deps[Api.safety],
deps[Api.tool_runtime],
deps[Api.tool_groups],
deps[Api.conversations],
policy,
telemetry_enabled,
)
await impl.initialize()
return impl

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,383 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
from llama_stack.apis.agents import (
Agent,
AgentConfig,
AgentCreateResponse,
Agents,
AgentSessionCreateResponse,
AgentStepResponse,
AgentToolGroup,
AgentTurnCreateRequest,
AgentTurnResumeRequest,
Document,
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseObject,
Order,
Session,
Turn,
)
from llama_stack.apis.agents.agents import ResponseGuardrail
from llama_stack.apis.agents.openai_responses import OpenAIResponseText
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.inference import (
Inference,
ToolConfig,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from .agent_instance import ChatAgent
from .config import MetaReferenceAgentsImplConfig
from .persistence import AgentInfo
from .responses.openai_responses import OpenAIResponsesImpl
logger = get_logger(name=__name__, category="agents::meta_reference")
class MetaReferenceAgentsImpl(Agents):
def __init__(
self,
config: MetaReferenceAgentsImplConfig,
inference_api: Inference,
vector_io_api: VectorIO,
safety_api: Safety,
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
conversations_api: Conversations,
policy: list[AccessRule],
telemetry_enabled: bool = False,
):
self.config = config
self.inference_api = inference_api
self.vector_io_api = vector_io_api
self.safety_api = safety_api
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
self.conversations_api = conversations_api
self.telemetry_enabled = telemetry_enabled
self.in_memory_store = InmemoryKVStoreImpl()
self.openai_responses_impl: OpenAIResponsesImpl | None = None
self.policy = policy
async def initialize(self) -> None:
self.persistence_store = await kvstore_impl(self.config.persistence.agent_state)
self.responses_store = ResponsesStore(self.config.persistence.responses, self.policy)
await self.responses_store.initialize()
self.openai_responses_impl = OpenAIResponsesImpl(
inference_api=self.inference_api,
tool_groups_api=self.tool_groups_api,
tool_runtime_api=self.tool_runtime_api,
responses_store=self.responses_store,
vector_io_api=self.vector_io_api,
safety_api=self.safety_api,
conversations_api=self.conversations_api,
)
async def create_agent(
self,
agent_config: AgentConfig,
) -> AgentCreateResponse:
agent_id = str(uuid.uuid4())
created_at = datetime.now(UTC)
agent_info = AgentInfo(
**agent_config.model_dump(),
created_at=created_at,
)
# Store the agent info
await self.persistence_store.set(
key=f"agent:{agent_id}",
value=agent_info.model_dump_json(),
)
return AgentCreateResponse(
agent_id=agent_id,
)
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
agent_info_json = await self.persistence_store.get(
key=f"agent:{agent_id}",
)
if not agent_info_json:
raise ValueError(f"Could not find agent info for {agent_id}")
try:
agent_info = AgentInfo.model_validate_json(agent_info_json)
except Exception as e:
raise ValueError(f"Could not validate agent info for {agent_id}") from e
return ChatAgent(
agent_id=agent_id,
agent_config=agent_info,
inference_api=self.inference_api,
safety_api=self.safety_api,
vector_io_api=self.vector_io_api,
tool_runtime_api=self.tool_runtime_api,
tool_groups_api=self.tool_groups_api,
persistence_store=(
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
),
created_at=agent_info.created_at,
policy=self.policy,
telemetry_enabled=self.telemetry_enabled,
)
async def create_agent_session(
self,
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse:
agent = await self._get_agent_impl(agent_id)
session_id = await agent.create_session(session_name)
return AgentSessionCreateResponse(
session_id=session_id,
)
async def create_agent_turn(
self,
agent_id: str,
session_id: str,
messages: list[UserMessage | ToolResponseMessage],
toolgroups: list[AgentToolGroup] | None = None,
documents: list[Document] | None = None,
stream: bool | None = False,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
agent_id=agent_id,
session_id=session_id,
messages=messages,
stream=True,
toolgroups=toolgroups,
documents=documents,
tool_config=tool_config,
)
if stream:
return self._create_agent_turn_streaming(request)
else:
raise NotImplementedError("Non-streaming agent turns not yet implemented")
async def _create_agent_turn_streaming(
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
agent = await self._get_agent_impl(request.agent_id)
async for event in agent.create_and_execute_turn(request):
yield event
async def resume_agent_turn(
self,
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: list[ToolResponse],
stream: bool | None = False,
) -> AsyncGenerator:
request = AgentTurnResumeRequest(
agent_id=agent_id,
session_id=session_id,
turn_id=turn_id,
tool_responses=tool_responses,
stream=stream,
)
if stream:
return self._continue_agent_turn_streaming(request)
else:
raise NotImplementedError("Non-streaming agent turns not yet implemented")
async def _continue_agent_turn_streaming(
self,
request: AgentTurnResumeRequest,
) -> AsyncGenerator:
agent = await self._get_agent_impl(request.agent_id)
async for event in agent.resume_turn(request):
yield event
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
agent = await self._get_agent_impl(agent_id)
turn = await agent.storage.get_session_turn(session_id, turn_id)
return turn
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
turn = await self.get_agents_turn(agent_id, session_id, turn_id)
for step in turn.steps:
if step.step_id == step_id:
return AgentStepResponse(step=step)
raise ValueError(f"Provided step_id {step_id} could not be found")
async def get_agents_session(
self,
agent_id: str,
session_id: str,
turn_ids: list[str] | None = None,
) -> Session:
agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id)
turns = await agent.storage.get_session_turns(session_id)
if turn_ids:
turns = [turn for turn in turns if turn.turn_id in turn_ids]
return Session(
session_name=session_info.session_name,
session_id=session_id,
turns=turns,
started_at=session_info.started_at,
)
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
agent = await self._get_agent_impl(agent_id)
# Delete turns first, then the session
await agent.storage.delete_session_turns(session_id)
await agent.storage.delete_session(session_id)
async def delete_agent(self, agent_id: str) -> None:
# First get all sessions for this agent
agent = await self._get_agent_impl(agent_id)
sessions = await agent.storage.list_sessions()
# Delete all sessions
for session in sessions:
await self.delete_agents_session(agent_id, session.session_id)
# Finally delete the agent itself
await self.persistence_store.delete(f"agent:{agent_id}")
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
agent_keys = await self.persistence_store.keys_in_range("agent:", "agent:\xff")
agent_list: list[Agent] = []
for agent_key in agent_keys:
agent_id = agent_key.split(":")[1]
# Get the agent info using the key
agent_info_json = await self.persistence_store.get(agent_key)
if not agent_info_json:
logger.error(f"Could not find agent info for key {agent_key}")
continue
try:
agent_info = AgentInfo.model_validate_json(agent_info_json)
agent_list.append(
Agent(
agent_id=agent_id,
agent_config=agent_info,
created_at=agent_info.created_at,
)
)
except Exception as e:
logger.error(f"Error parsing agent info for {agent_id}: {e}")
continue
# Convert Agent objects to dictionaries
agent_dicts = [agent.model_dump() for agent in agent_list]
return paginate_records(agent_dicts, start_index, limit)
async def get_agent(self, agent_id: str) -> Agent:
chat_agent = await self._get_agent_impl(agent_id)
agent = Agent(
agent_id=agent_id,
agent_config=chat_agent.agent_config,
created_at=chat_agent.created_at,
)
return agent
async def list_agent_sessions(
self, agent_id: str, start_index: int | None = None, limit: int | None = None
) -> PaginatedResponse:
agent = await self._get_agent_impl(agent_id)
sessions = await agent.storage.list_sessions()
# Convert Session objects to dictionaries
session_dicts = [session.model_dump() for session in sessions]
return paginate_records(session_dicts, start_index, limit)
async def shutdown(self) -> None:
pass
# OpenAI responses
async def get_openai_response(
self,
response_id: str,
) -> OpenAIResponseObject:
return await self.openai_responses_impl.get_openai_response(response_id)
async def create_openai_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
guardrails: list[ResponseGuardrail] | None = None,
) -> OpenAIResponseObject:
return await self.openai_responses_impl.create_openai_response(
input,
model,
instructions,
previous_response_id,
conversation,
store,
stream,
temperature,
text,
tools,
include,
max_infer_iters,
guardrails,
)
async def list_openai_responses(
self,
after: str | None = None,
limit: int | None = 50,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIResponseObject:
return await self.openai_responses_impl.list_openai_responses(after, limit, model, order)
async def list_openai_response_input_items(
self,
response_id: str,
after: str | None = None,
before: str | None = None,
include: list[str] | None = None,
limit: int | None = 20,
order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem:
return await self.openai_responses_impl.list_openai_response_input_items(
response_id, after, before, include, limit, order
)
async def delete_openai_response(self, response_id: str) -> None:
return await self.openai_responses_impl.delete_openai_response(response_id)

View file

@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
class AgentPersistenceConfig(BaseModel):
"""Nested persistence configuration for agents."""
agent_state: KVStoreReference
responses: ResponsesStoreReference
class MetaReferenceAgentsImplConfig(BaseModel):
persistence: AgentPersistenceConfig
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"persistence": {
"agent_state": KVStoreReference(
backend="kv_default",
namespace="agents",
).model_dump(exclude_none=True),
"responses": ResponsesStoreReference(
backend="sql_default",
table_name="responses",
).model_dump(exclude_none=True),
}
}

View file

@ -0,0 +1,228 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import uuid
from datetime import UTC, datetime
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.apis.common.errors import SessionNotFoundError
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
log = get_logger(name=__name__, category="agents::meta_reference")
class AgentSessionInfo(Session):
# TODO: is this used anywhere?
vector_db_id: str | None = None
started_at: datetime
owner: User | None = None
identifier: str | None = None
type: str = "session"
class AgentInfo(AgentConfig):
created_at: datetime
class AgentPersistence:
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
self.agent_id = agent_id
self.kvstore = kvstore
self.policy = policy
async def create_session(self, name: str) -> str:
session_id = str(uuid.uuid4())
# Get current user's auth attributes for new sessions
user = get_authenticated_user()
session_info = AgentSessionInfo(
session_id=session_id,
session_name=name,
started_at=datetime.now(UTC),
owner=user,
turns=[],
identifier=name, # should this be qualified in any way?
)
if not is_action_allowed(self.policy, "create", session_info, user):
raise AccessDeniedError("create", session_info, user)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.model_dump_json(),
)
return session_id
async def get_session_info(self, session_id: str) -> AgentSessionInfo:
value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}",
)
if not value:
raise SessionNotFoundError(session_id)
session_info = AgentSessionInfo(**json.loads(value))
# Check access to session
if not self._check_session_access(session_info):
return None
return session_info
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
"""Check if current user has access to the session."""
# Handle backward compatibility for old sessions without access control
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
return True
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
"""Get session info if the user has access to it. For internal use by sub-session methods."""
session_info = await self.get_session_info(session_id)
if not session_info:
return None
return session_info
async def add_vector_db_to_session(self, session_id: str, vector_db_id: str):
session_info = await self.get_session_if_accessible(session_id)
if session_info is None:
raise SessionNotFoundError(session_id)
session_info.vector_db_id = vector_db_id
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.model_dump_json(),
)
async def add_turn_to_session(self, session_id: str, turn: Turn):
if not await self.get_session_if_accessible(session_id):
raise SessionNotFoundError(session_id)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
value=turn.model_dump_json(),
)
async def get_session_turns(self, session_id: str) -> list[Turn]:
if not await self.get_session_if_accessible(session_id):
raise SessionNotFoundError(session_id)
values = await self.kvstore.values_in_range(
start_key=f"session:{self.agent_id}:{session_id}:",
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
)
turns = []
for value in values:
try:
turn = Turn(**json.loads(value))
turns.append(turn)
except Exception as e:
log.error(f"Error parsing turn: {e}")
continue
# The kvstore does not guarantee order, so we sort by started_at
# to ensure consistent ordering of turns.
turns.sort(key=lambda t: t.started_at)
return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
if not await self.get_session_if_accessible(session_id):
raise SessionNotFoundError(session_id)
value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}:{turn_id}",
)
if not value:
return None
return Turn(**json.loads(value))
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
if not await self.get_session_if_accessible(session_id):
raise SessionNotFoundError(session_id)
await self.kvstore.set(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
value=step.model_dump_json(),
)
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> ToolExecutionStep | None:
if not await self.get_session_if_accessible(session_id):
return None
value = await self.kvstore.get(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
)
return ToolExecutionStep(**json.loads(value)) if value else None
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
if not await self.get_session_if_accessible(session_id):
raise SessionNotFoundError(session_id)
await self.kvstore.set(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
value=str(num_infer_iters),
)
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> int | None:
if not await self.get_session_if_accessible(session_id):
return None
value = await self.kvstore.get(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
)
return int(value) if value else None
async def list_sessions(self) -> list[Session]:
values = await self.kvstore.values_in_range(
start_key=f"session:{self.agent_id}:",
end_key=f"session:{self.agent_id}:\xff\xff\xff\xff",
)
sessions = []
for value in values:
try:
data = json.loads(value)
if "turn_id" in data:
continue
session_info = Session(**data)
sessions.append(session_info)
except Exception as e:
log.error(f"Error parsing session info: {e}")
continue
return sessions
async def delete_session_turns(self, session_id: str) -> None:
"""Delete all turns and their associated data for a session.
Args:
session_id: The ID of the session whose turns should be deleted.
"""
turns = await self.get_session_turns(session_id)
for turn in turns:
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}")
async def delete_session(self, session_id: str) -> None:
"""Delete a session and all its associated turns.
Args:
session_id: The ID of the session to delete.
Raises:
ValueError: If the session does not exist.
"""
session_info = await self.get_session_info(session_id)
if session_info is None:
raise SessionNotFoundError(session_id)
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}")

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,424 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
import uuid
from collections.abc import AsyncIterator
from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.agents import Order
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from llama_stack.apis.agents.openai_responses import (
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseInput,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseText,
OpenAIResponseTextFormat,
)
from llama_stack.apis.common.errors import (
InvalidConversationIdError,
)
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.conversations.conversations import ConversationItem
from llama_stack.apis.inference import (
Inference,
OpenAIMessageParam,
OpenAISystemMessageParam,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.utils.responses.responses_store import (
ResponsesStore,
_OpenAIResponseObjectWithInputAndMessages,
)
from .streaming import StreamingResponseOrchestrator
from .tool_executor import ToolExecutor
from .types import ChatCompletionContext, ToolContext
from .utils import (
convert_response_input_to_chat_messages,
convert_response_text_to_chat_response_format,
extract_guardrail_ids,
)
logger = get_logger(name=__name__, category="openai_responses")
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
input_items: ListOpenAIResponseInputItem
response: OpenAIResponseObject
class OpenAIResponsesImpl:
def __init__(
self,
inference_api: Inference,
tool_groups_api: ToolGroups,
tool_runtime_api: ToolRuntime,
responses_store: ResponsesStore,
vector_io_api: VectorIO, # VectorIO
safety_api: Safety,
conversations_api: Conversations,
):
self.inference_api = inference_api
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
self.responses_store = responses_store
self.vector_io_api = vector_io_api
self.safety_api = safety_api
self.conversations_api = conversations_api
self.tool_executor = ToolExecutor(
tool_groups_api=tool_groups_api,
tool_runtime_api=tool_runtime_api,
vector_io_api=vector_io_api,
)
async def _prepend_previous_response(
self,
input: str | list[OpenAIResponseInput],
previous_response: _OpenAIResponseObjectWithInputAndMessages,
):
new_input_items = previous_response.input.copy()
new_input_items.extend(previous_response.output)
if isinstance(input, str):
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
else:
new_input_items.extend(input)
return new_input_items
async def _process_input_with_previous_response(
self,
input: str | list[OpenAIResponseInput],
tools: list[OpenAIResponseInputTool] | None,
previous_response_id: str | None,
conversation: str | None,
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
"""Process input with optional previous response context.
Returns:
tuple: (all_input for storage, messages for chat completion, tool context)
"""
tool_context = ToolContext(tools)
if previous_response_id:
previous_response: _OpenAIResponseObjectWithInputAndMessages = (
await self.responses_store.get_response_object(previous_response_id)
)
all_input = await self._prepend_previous_response(input, previous_response)
if previous_response.messages:
# Use stored messages directly and convert only new input
message_adapter = TypeAdapter(list[OpenAIMessageParam])
messages = message_adapter.validate_python(previous_response.messages)
new_messages = await convert_response_input_to_chat_messages(input, previous_messages=messages)
messages.extend(new_messages)
else:
# Backward compatibility: reconstruct from inputs
messages = await convert_response_input_to_chat_messages(all_input)
tool_context.recover_tools_from_previous_response(previous_response)
elif conversation is not None:
conversation_items = await self.conversations_api.list_items(conversation, order="asc")
# Use stored messages as source of truth (like previous_response.messages)
stored_messages = await self.responses_store.get_conversation_messages(conversation)
all_input = input
if not conversation_items.data:
# First turn - just convert the new input
messages = await convert_response_input_to_chat_messages(input)
else:
if not stored_messages:
all_input = conversation_items.data
if isinstance(input, str):
all_input.append(
OpenAIResponseMessage(
role="user", content=[OpenAIResponseInputMessageContentText(text=input)]
)
)
else:
all_input.extend(input)
else:
all_input = input
messages = stored_messages or []
new_messages = await convert_response_input_to_chat_messages(all_input, previous_messages=messages)
messages.extend(new_messages)
else:
all_input = input
messages = await convert_response_input_to_chat_messages(all_input)
return all_input, messages, tool_context
async def get_openai_response(
self,
response_id: str,
) -> OpenAIResponseObject:
response_with_input = await self.responses_store.get_response_object(response_id)
return response_with_input.to_response_object()
async def list_openai_responses(
self,
after: str | None = None,
limit: int | None = 50,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIResponseObject:
return await self.responses_store.list_responses(after, limit, model, order)
async def list_openai_response_input_items(
self,
response_id: str,
after: str | None = None,
before: str | None = None,
include: list[str] | None = None,
limit: int | None = 20,
order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem:
"""List input items for a given OpenAI response.
:param response_id: The ID of the response to retrieve input items for.
:param after: An item ID to list items after, used for pagination.
:param before: An item ID to list items before, used for pagination.
:param include: Additional fields to include in the response.
:param limit: A limit on the number of objects to be returned.
:param order: The order to return the input items in.
:returns: An ListOpenAIResponseInputItem.
"""
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
async def _store_response(
self,
response: OpenAIResponseObject,
input: str | list[OpenAIResponseInput],
messages: list[OpenAIMessageParam],
) -> None:
new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(input, str):
# synthesize a message from the input string
input_content = OpenAIResponseInputMessageContentText(text=input)
input_content_item = OpenAIResponseMessage(
role="user",
content=[input_content],
id=new_input_id,
)
input_items_data = [input_content_item]
else:
# we already have a list of messages
input_items_data = []
for input_item in input:
if isinstance(input_item, OpenAIResponseMessage):
# These may or may not already have an id, so dump to dict, check for id, and add if missing
input_item_dict = input_item.model_dump()
if "id" not in input_item_dict:
input_item_dict["id"] = new_input_id
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
else:
input_items_data.append(input_item)
await self.responses_store.store_response_object(
response_object=response,
input=input_items_data,
messages=messages,
)
async def create_openai_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
guardrails: list[ResponseGuardrailSpec] | None = None,
):
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
guardrail_ids = extract_guardrail_ids(guardrails) if guardrails else []
if conversation is not None:
if previous_response_id is not None:
raise ValueError(
"Mutually exclusive parameters: 'previous_response_id' and 'conversation'. Ensure you are only providing one of these parameters."
)
if not conversation.startswith("conv_"):
raise InvalidConversationIdError(conversation)
stream_gen = self._create_streaming_response(
input=input,
conversation=conversation,
model=model,
instructions=instructions,
previous_response_id=previous_response_id,
store=store,
temperature=temperature,
text=text,
tools=tools,
max_infer_iters=max_infer_iters,
guardrail_ids=guardrail_ids,
)
if stream:
return stream_gen
else:
final_response = None
final_event_type = None
failed_response = None
async for stream_chunk in stream_gen:
if stream_chunk.type in {"response.completed", "response.incomplete"}:
if final_response is not None:
raise ValueError(
"The response stream produced multiple terminal responses! "
f"Earlier response from {final_event_type}"
)
final_response = stream_chunk.response
final_event_type = stream_chunk.type
elif stream_chunk.type == "response.failed":
failed_response = stream_chunk.response
if failed_response is not None:
error_message = (
failed_response.error.message
if failed_response and failed_response.error
else "Response stream failed without error details"
)
raise RuntimeError(f"OpenAI response failed: {error_message}")
if final_response is None:
raise ValueError("The response stream never reached a terminal state")
return final_response
async def _create_streaming_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
guardrail_ids: list[str] | None = None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Input preprocessing
all_input, messages, tool_context = await self._process_input_with_previous_response(
input, tools, previous_response_id, conversation
)
if instructions:
messages.insert(0, OpenAISystemMessageParam(content=instructions))
# Structured outputs
response_format = await convert_response_text_to_chat_response_format(text)
ctx = ChatCompletionContext(
model=model,
messages=messages,
response_tools=tools,
temperature=temperature,
response_format=response_format,
tool_context=tool_context,
inputs=all_input,
)
# Create orchestrator and delegate streaming logic
response_id = f"resp_{uuid.uuid4()}"
created_at = int(time.time())
orchestrator = StreamingResponseOrchestrator(
inference_api=self.inference_api,
ctx=ctx,
response_id=response_id,
created_at=created_at,
text=text,
max_infer_iters=max_infer_iters,
tool_executor=self.tool_executor,
safety_api=self.safety_api,
guardrail_ids=guardrail_ids,
instructions=instructions,
)
# Stream the response
final_response = None
failed_response = None
output_items = []
async for stream_chunk in orchestrator.create_response():
if stream_chunk.type in {"response.completed", "response.incomplete"}:
final_response = stream_chunk.response
elif stream_chunk.type == "response.failed":
failed_response = stream_chunk.response
if stream_chunk.type == "response.output_item.done":
item = stream_chunk.item
output_items.append(item)
# Store and sync before yielding terminal events
# This ensures the storage/syncing happens even if the consumer breaks after receiving the event
if (
stream_chunk.type in {"response.completed", "response.incomplete"}
and final_response
and failed_response is None
):
messages_to_store = list(
filter(lambda x: not isinstance(x, OpenAISystemMessageParam), orchestrator.final_messages)
)
if store:
# TODO: we really should work off of output_items instead of "final_messages"
await self._store_response(
response=final_response,
input=all_input,
messages=messages_to_store,
)
if conversation:
await self._sync_response_to_conversation(conversation, input, output_items)
await self.responses_store.store_conversation_messages(conversation, messages_to_store)
yield stream_chunk
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
return await self.responses_store.delete_response_object(response_id)
async def _sync_response_to_conversation(
self, conversation_id: str, input: str | list[OpenAIResponseInput] | None, output_items: list[ConversationItem]
) -> None:
"""Sync content and response messages to the conversation."""
conversation_items = []
if isinstance(input, str):
conversation_items.append(
OpenAIResponseMessage(role="user", content=[OpenAIResponseInputMessageContentText(text=input)])
)
elif isinstance(input, list):
conversation_items.extend(input)
conversation_items.extend(output_items)
adapter = TypeAdapter(list[ConversationItem])
validated_items = adapter.validate_python(conversation_items)
await self.conversations_api.add_items(conversation_id, validated_items)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,449 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
from collections.abc import AsyncIterator
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolMCP,
OpenAIResponseObjectStreamResponseFileSearchCallCompleted,
OpenAIResponseObjectStreamResponseFileSearchCallInProgress,
OpenAIResponseObjectStreamResponseFileSearchCallSearching,
OpenAIResponseObjectStreamResponseMcpCallCompleted,
OpenAIResponseObjectStreamResponseMcpCallFailed,
OpenAIResponseObjectStreamResponseMcpCallInProgress,
OpenAIResponseObjectStreamResponseWebSearchCallCompleted,
OpenAIResponseObjectStreamResponseWebSearchCallInProgress,
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFileSearchToolCallResults,
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.common.content_types import (
ImageContentItem,
TextContentItem,
)
from llama_stack.apis.inference import (
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCall,
OpenAIImageURL,
OpenAIToolMessageParam,
)
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.telemetry import tracing
from llama_stack.log import get_logger
from .types import ChatCompletionContext, ToolExecutionResult
logger = get_logger(name=__name__, category="agents::meta_reference")
class ToolExecutor:
def __init__(
self,
tool_groups_api: ToolGroups,
tool_runtime_api: ToolRuntime,
vector_io_api: VectorIO,
):
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
self.vector_io_api = vector_io_api
async def execute_tool_call(
self,
tool_call: OpenAIChatCompletionToolCall,
ctx: ChatCompletionContext,
sequence_number: int,
output_index: int,
item_id: str,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> AsyncIterator[ToolExecutionResult]:
tool_call_id = tool_call.id
function = tool_call.function
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
if not function or not tool_call_id or not function.name:
yield ToolExecutionResult(sequence_number=sequence_number)
return
# Emit progress events for tool execution start
async for event_result in self._emit_progress_events(
function.name, ctx, sequence_number, output_index, item_id, mcp_tool_to_server
):
sequence_number = event_result.sequence_number
yield event_result
# Execute the actual tool call
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
# Emit completion events for tool execution
has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
async for event_result in self._emit_completion_events(
function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
):
sequence_number = event_result.sequence_number
yield event_result
# Build result messages from tool execution
output_message, input_message = await self._build_result_messages(
function, tool_call_id, item_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server
)
# Yield the final result
yield ToolExecutionResult(
sequence_number=sequence_number,
final_output_message=output_message,
final_input_message=input_message,
citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
)
async def _execute_knowledge_search_via_vector_store(
self,
query: str,
response_file_search_tool: OpenAIResponseInputToolFileSearch,
) -> ToolInvocationResult:
"""Execute knowledge search using vector_stores.search API with filters support."""
search_results = []
# Create search tasks for all vector stores
async def search_single_store(vector_store_id):
try:
search_response = await self.vector_io_api.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=response_file_search_tool.filters,
max_num_results=response_file_search_tool.max_num_results,
ranking_options=response_file_search_tool.ranking_options,
rewrite_query=False,
)
return search_response.data
except Exception as e:
logger.warning(f"Failed to search vector store {vector_store_id}: {e}")
return []
# Run all searches in parallel using gather
search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids]
all_results = await asyncio.gather(*search_tasks)
# Flatten results
for results in all_results:
search_results.extend(results)
content_items = []
content_items.append(
TextContentItem(
text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n"
)
)
unique_files = set()
for i, result_item in enumerate(search_results):
chunk_text = result_item.content[0].text if result_item.content else ""
# Get file_id from attributes if result_item.file_id is empty
file_id = result_item.file_id or (
result_item.attributes.get("document_id") if result_item.attributes else None
)
metadata_text = f"document_id: {file_id}, score: {result_item.score}"
if result_item.attributes:
metadata_text += f", attributes: {result_item.attributes}"
text_content = f"[{i + 1}] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n"
content_items.append(TextContentItem(text=text_content))
unique_files.add(file_id)
content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
citation_instruction = ""
if unique_files:
citation_instruction = (
" Cite sources immediately at the end of sentences before punctuation, using `<|file-id|>` format (e.g., 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'). "
"Do not add extra punctuation. Use only the file IDs provided (do not invent new ones)."
)
content_items.append(
TextContentItem(
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.{citation_instruction}\n',
)
)
# handling missing attributes for old versions
citation_files = {}
for result in search_results:
file_id = result.file_id
if not file_id and result.attributes:
file_id = result.attributes.get("document_id")
filename = result.filename
if not filename and result.attributes:
filename = result.attributes.get("filename")
if not filename:
filename = "unknown"
citation_files[file_id] = filename
return ToolInvocationResult(
content=content_items,
metadata={
"document_ids": [r.file_id for r in search_results],
"chunks": [r.content[0].text if r.content else "" for r in search_results],
"scores": [r.score for r in search_results],
"citation_files": citation_files,
},
)
async def _emit_progress_events(
self,
function_name: str,
ctx: ChatCompletionContext,
sequence_number: int,
output_index: int,
item_id: str,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> AsyncIterator[ToolExecutionResult]:
"""Emit progress events for tool execution start."""
# Emit in_progress event based on tool type (only for tools with specific streaming events)
progress_event = None
if mcp_tool_to_server and function_name in mcp_tool_to_server:
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
elif function_name == "web_search":
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
elif function_name == "knowledge_search":
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
if progress_event:
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
# For web search, emit searching event
if function_name == "web_search":
sequence_number += 1
searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
# For file search, emit searching event
if function_name == "knowledge_search":
sequence_number += 1
searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
async def _execute_tool(
self,
function_name: str,
tool_kwargs: dict,
ctx: ChatCompletionContext,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> tuple[Exception | None, any]:
"""Execute the tool and return error exception and result."""
error_exc = None
result = None
try:
if mcp_tool_to_server and function_name in mcp_tool_to_server:
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool
mcp_tool = mcp_tool_to_server[function_name]
attributes = {
"server_label": mcp_tool.server_label,
"server_url": mcp_tool.server_url,
"tool_name": function_name,
}
async with tracing.span("invoke_mcp_tool", attributes):
result = await invoke_mcp_tool(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
tool_name=function_name,
kwargs=tool_kwargs,
)
elif function_name == "knowledge_search":
response_file_search_tool = next(
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
None,
)
if response_file_search_tool:
# Use vector_stores.search API instead of knowledge_search tool
# to support filters and ranking_options
query = tool_kwargs.get("query", "")
async with tracing.span("knowledge_search", {}):
result = await self._execute_knowledge_search_via_vector_store(
query=query,
response_file_search_tool=response_file_search_tool,
)
else:
attributes = {
"tool_name": function_name,
}
async with tracing.span("invoke_tool", attributes):
result = await self.tool_runtime_api.invoke_tool(
tool_name=function_name,
kwargs=tool_kwargs,
)
except Exception as e:
error_exc = e
return error_exc, result
async def _emit_completion_events(
self,
function_name: str,
ctx: ChatCompletionContext,
sequence_number: int,
output_index: int,
item_id: str,
has_error: bool,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> AsyncIterator[ToolExecutionResult]:
"""Emit completion or failure events for tool execution."""
completion_event = None
if mcp_tool_to_server and function_name in mcp_tool_to_server:
sequence_number += 1
if has_error:
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
sequence_number=sequence_number,
)
else:
completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
sequence_number=sequence_number,
)
elif function_name == "web_search":
sequence_number += 1
completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
elif function_name == "knowledge_search":
sequence_number += 1
completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
if completion_event:
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
async def _build_result_messages(
self,
function,
tool_call_id: str,
item_id: str,
tool_kwargs: dict,
ctx: ChatCompletionContext,
error_exc: Exception | None,
result: any,
has_error: bool,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> tuple[any, any]:
"""Build output and input messages from tool execution results."""
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
# Build output message
if mcp_tool_to_server and function.name in mcp_tool_to_server:
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageMCPCall,
)
message = OpenAIResponseOutputMessageMCPCall(
id=item_id,
arguments=function.arguments,
name=function.name,
server_label=mcp_tool_to_server[function.name].server_label,
)
if error_exc:
message.error = str(error_exc)
elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
message.error = f"Error (code {result.error_code}): {result.error_message}"
elif result and result.content:
message.output = interleaved_content_as_str(result.content)
else:
if function.name == "web_search":
message = OpenAIResponseOutputMessageWebSearchToolCall(
id=item_id,
status="completed",
)
if has_error:
message.status = "failed"
elif function.name == "knowledge_search":
message = OpenAIResponseOutputMessageFileSearchToolCall(
id=item_id,
queries=[tool_kwargs.get("query", "")],
status="completed",
)
if result and "document_ids" in result.metadata:
message.results = []
for i, doc_id in enumerate(result.metadata["document_ids"]):
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
score = result.metadata["scores"][i] if "scores" in result.metadata else None
message.results.append(
OpenAIResponseOutputMessageFileSearchToolCallResults(
file_id=doc_id,
filename=doc_id,
text=text,
score=score,
attributes={},
)
)
if has_error:
message.status = "failed"
else:
raise ValueError(f"Unknown tool {function.name} called")
# Build input message
input_message = None
if result and result.content:
if isinstance(result.content, str):
content = result.content
elif isinstance(result.content, list):
content = []
for item in result.content:
if isinstance(item, TextContentItem):
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
elif isinstance(item, ImageContentItem):
if item.image.data:
url = f"data:image;base64,{item.image.data}"
else:
url = item.image.url
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
else:
raise ValueError(f"Unknown result content type: {type(item)}")
content.append(part)
else:
raise ValueError(f"Unknown result content type: {type(result.content)}")
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
else:
text = str(error_exc) if error_exc else "Tool execution failed"
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
return message, input_message

View file

@ -0,0 +1,194 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from dataclasses import dataclass
from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
OpenAIResponseInputToolWebSearch,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseMCPApprovalResponse,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseOutput,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseTool,
OpenAIResponseToolMCP,
)
from llama_stack.apis.inference import OpenAIChatCompletionToolCall, OpenAIMessageParam, OpenAIResponseFormatParam
class ToolExecutionResult(BaseModel):
"""Result of streaming tool execution."""
stream_event: OpenAIResponseObjectStream | None = None
sequence_number: int
final_output_message: OpenAIResponseOutput | None = None
final_input_message: OpenAIMessageParam | None = None
citation_files: dict[str, str] | None = None
@dataclass
class ChatCompletionResult:
"""Result of processing streaming chat completion chunks."""
response_id: str
content: list[str]
tool_calls: dict[int, OpenAIChatCompletionToolCall]
created: int
model: str
finish_reason: str
message_item_id: str # For streaming events
tool_call_item_ids: dict[int, str] # For streaming events
content_part_emitted: bool # Tracking state
@property
def content_text(self) -> str:
"""Get joined content as string."""
return "".join(self.content)
@property
def has_tool_calls(self) -> bool:
"""Check if there are any tool calls."""
return bool(self.tool_calls)
class ToolContext(BaseModel):
"""Holds information about tools from this and (if relevant)
previous response in order to facilitate reuse of previous
listings where appropriate."""
# tools argument passed into current request:
current_tools: list[OpenAIResponseInputTool]
# reconstructed map of tool -> mcp server from previous response:
previous_tools: dict[str, OpenAIResponseInputToolMCP]
# reusable mcp-list-tools objects from previous response:
previous_tool_listings: list[OpenAIResponseOutputMessageMCPListTools]
# tool arguments from current request that still need to be processed:
tools_to_process: list[OpenAIResponseInputTool]
def __init__(
self,
current_tools: list[OpenAIResponseInputTool] | None,
):
super().__init__(
current_tools=current_tools or [],
previous_tools={},
previous_tool_listings=[],
tools_to_process=current_tools or [],
)
def recover_tools_from_previous_response(
self,
previous_response: OpenAIResponseObject,
):
"""Determine which mcp_list_tools objects from previous response we can reuse."""
if self.current_tools and previous_response.tools:
previous_tools_by_label: dict[str, OpenAIResponseToolMCP] = {}
for tool in previous_response.tools:
if isinstance(tool, OpenAIResponseToolMCP):
previous_tools_by_label[tool.server_label] = tool
# collect tool definitions which are the same in current and previous requests:
tools_to_process = []
matched: dict[str, OpenAIResponseInputToolMCP] = {}
for tool in self.current_tools:
if isinstance(tool, OpenAIResponseInputToolMCP) and tool.server_label in previous_tools_by_label:
previous_tool = previous_tools_by_label[tool.server_label]
if previous_tool.allowed_tools == tool.allowed_tools:
matched[tool.server_label] = tool
else:
tools_to_process.append(tool)
else:
tools_to_process.append(tool)
# tools that are not the same or were not previously defined need to be processed:
self.tools_to_process = tools_to_process
# for all matched definitions, get the mcp_list_tools objects from the previous output:
self.previous_tool_listings = [
obj for obj in previous_response.output if obj.type == "mcp_list_tools" and obj.server_label in matched
]
# reconstruct the tool to server mappings that can be reused:
for listing in self.previous_tool_listings:
definition = matched[listing.server_label]
for tool in listing.tools:
self.previous_tools[tool.name] = definition
def available_tools(self) -> list[OpenAIResponseTool]:
if not self.current_tools:
return []
def convert_tool(tool: OpenAIResponseInputTool) -> OpenAIResponseTool:
if isinstance(tool, OpenAIResponseInputToolWebSearch):
return tool
if isinstance(tool, OpenAIResponseInputToolFileSearch):
return tool
if isinstance(tool, OpenAIResponseInputToolFunction):
return tool
if isinstance(tool, OpenAIResponseInputToolMCP):
return OpenAIResponseToolMCP(
server_label=tool.server_label,
allowed_tools=tool.allowed_tools,
)
return [convert_tool(tool) for tool in self.current_tools]
class ChatCompletionContext(BaseModel):
model: str
messages: list[OpenAIMessageParam]
response_tools: list[OpenAIResponseInputTool] | None = None
chat_tools: list[ChatCompletionToolParam] | None = None
temperature: float | None
response_format: OpenAIResponseFormatParam
tool_context: ToolContext | None
approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
def __init__(
self,
model: str,
messages: list[OpenAIMessageParam],
response_tools: list[OpenAIResponseInputTool] | None,
temperature: float | None,
response_format: OpenAIResponseFormatParam,
tool_context: ToolContext,
inputs: list[OpenAIResponseInput] | str,
):
super().__init__(
model=model,
messages=messages,
response_tools=response_tools,
temperature=temperature,
response_format=response_format,
tool_context=tool_context,
)
if not isinstance(inputs, str):
self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]
self.approval_responses = {
input.approval_request_id: input for input in inputs if input.type == "mcp_approval_response"
}
def approval_response(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalResponse | None:
request = self._approval_request(tool_name, arguments)
return self.approval_responses.get(request.id, None) if request else None
def _approval_request(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalRequest | None:
for request in self.approval_requests:
if request.name == tool_name and request.arguments == arguments:
return request
return None
def available_tools(self) -> list[OpenAIResponseTool]:
if not self.tool_context:
return []
return self.tool_context.available_tools()

View file

@ -0,0 +1,365 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import re
import uuid
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseAnnotationFileCitation,
OpenAIResponseInput,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContent,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseMCPApprovalResponse,
OpenAIResponseMessage,
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseText,
)
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIDeveloperMessageParam,
OpenAIImageURL,
OpenAIJSONSchema,
OpenAIMessageParam,
OpenAIResponseFormatJSONObject,
OpenAIResponseFormatJSONSchema,
OpenAIResponseFormatParam,
OpenAIResponseFormatText,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.safety import Safety
async def convert_chat_choice_to_response_message(
choice: OpenAIChoice,
citation_files: dict[str, str] | None = None,
*,
message_id: str | None = None,
) -> OpenAIResponseMessage:
"""Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
output_content = ""
if isinstance(choice.message.content, str):
output_content = choice.message.content
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
output_content = choice.message.content.text
else:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
)
annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})
return OpenAIResponseMessage(
id=message_id or f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
status="completed",
role="assistant",
)
async def convert_response_content_to_chat_content(
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
) -> str | list[OpenAIChatCompletionContentPartParam]:
"""
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
The content schemas of each API look similar, but are not exactly the same.
"""
if isinstance(content, str):
return content
converted_parts = []
for content_part in content:
if isinstance(content_part, OpenAIResponseInputMessageContentText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
if content_part.image_url:
image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
elif isinstance(content_part, str):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
else:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context"
)
return converted_parts
async def convert_response_input_to_chat_messages(
input: str | list[OpenAIResponseInput],
previous_messages: list[OpenAIMessageParam] | None = None,
) -> list[OpenAIMessageParam]:
"""
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
:param input: The input to convert
:param previous_messages: Optional previous messages to check for function_call references
"""
messages: list[OpenAIMessageParam] = []
if isinstance(input, list):
# extract all OpenAIResponseInputFunctionToolCallOutput items
# so their corresponding OpenAIToolMessageParam instances can
# be added immediately following the corresponding
# OpenAIAssistantMessageParam
tool_call_results = {}
for input_item in input:
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
tool_call_results[input_item.call_id] = OpenAIToolMessageParam(
content=input_item.output,
tool_call_id=input_item.call_id,
)
for input_item in input:
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
# skip as these have been extracted and inserted in order
pass
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id=input_item.call_id,
function=OpenAIChatCompletionToolCallFunction(
name=input_item.name,
arguments=input_item.arguments,
),
)
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
if input_item.call_id in tool_call_results:
messages.append(tool_call_results[input_item.call_id])
del tool_call_results[input_item.call_id]
elif isinstance(input_item, OpenAIResponseOutputMessageMCPCall):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id=input_item.id,
function=OpenAIChatCompletionToolCallFunction(
name=input_item.name,
arguments=input_item.arguments,
),
)
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
messages.append(
OpenAIToolMessageParam(
content=input_item.output,
tool_call_id=input_item.id,
)
)
elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
# the tool list will be handled separately
pass
elif isinstance(input_item, OpenAIResponseMCPApprovalRequest) or isinstance(
input_item, OpenAIResponseMCPApprovalResponse
):
# these are handled by the responses impl itself and not pass through to chat completions
pass
else:
content = await convert_response_content_to_chat_content(input_item.content)
message_type = await get_message_type_by_role(input_item.role)
if message_type is None:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
)
# Skip user messages that duplicate the last user message in previous_messages
# This handles cases where input includes context for function_call_outputs
if previous_messages and input_item.role == "user":
last_user_msg = None
for msg in reversed(previous_messages):
if isinstance(msg, OpenAIUserMessageParam):
last_user_msg = msg
break
if last_user_msg:
last_user_content = getattr(last_user_msg, "content", None)
if last_user_content == content:
continue # Skip duplicate user message
messages.append(message_type(content=content))
if len(tool_call_results):
# Check if unpaired function_call_outputs reference function_calls from previous messages
if previous_messages:
previous_call_ids = _extract_tool_call_ids(previous_messages)
for call_id in list(tool_call_results.keys()):
if call_id in previous_call_ids:
# Valid: this output references a call from previous messages
# Add the tool message
messages.append(tool_call_results[call_id])
del tool_call_results[call_id]
# If still have unpaired outputs, error
if len(tool_call_results):
raise ValueError(
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
)
else:
messages.append(OpenAIUserMessageParam(content=input))
return messages
def _extract_tool_call_ids(messages: list[OpenAIMessageParam]) -> set[str]:
"""Extract all tool_call IDs from messages."""
call_ids = set()
for msg in messages:
if isinstance(msg, OpenAIAssistantMessageParam):
tool_calls = getattr(msg, "tool_calls", None)
if tool_calls:
for tool_call in tool_calls:
# tool_call is a Pydantic model, use attribute access
call_ids.add(tool_call.id)
return call_ids
async def convert_response_text_to_chat_response_format(
text: OpenAIResponseText,
) -> OpenAIResponseFormatParam:
"""
Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
"""
if not text.format or text.format["type"] == "text":
return OpenAIResponseFormatText(type="text")
if text.format["type"] == "json_object":
return OpenAIResponseFormatJSONObject()
if text.format["type"] == "json_schema":
return OpenAIResponseFormatJSONSchema(
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
)
raise ValueError(f"Unsupported text format: {text.format}")
async def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None:
"""Get the appropriate OpenAI message parameter type for a given role."""
role_to_type = {
"user": OpenAIUserMessageParam,
"system": OpenAISystemMessageParam,
"assistant": OpenAIAssistantMessageParam,
"developer": OpenAIDeveloperMessageParam,
}
return role_to_type.get(role)
def _extract_citations_from_text(
text: str, citation_files: dict[str, str]
) -> tuple[list[OpenAIResponseAnnotationFileCitation], str]:
"""Extract citation markers from text and create annotations
Args:
text: The text containing citation markers like [file-Cn3MSNn72ENTiiq11Qda4A]
citation_files: Dictionary mapping file_id to filename
Returns:
Tuple of (annotations_list, clean_text_without_markers)
"""
file_id_regex = re.compile(r"<\|(?P<file_id>file-[A-Za-z0-9_-]+)\|>")
annotations = []
parts = []
total_len = 0
last_end = 0
for m in file_id_regex.finditer(text):
# segment before the marker
prefix = text[last_end : m.start()]
# drop one space if it exists (since marker is at sentence end)
if prefix.endswith(" "):
prefix = prefix[:-1]
parts.append(prefix)
total_len += len(prefix)
fid = m.group(1)
if fid in citation_files:
annotations.append(
OpenAIResponseAnnotationFileCitation(
file_id=fid,
filename=citation_files[fid],
index=total_len, # index points to punctuation
)
)
last_end = m.end()
parts.append(text[last_end:])
cleaned_text = "".join(parts)
return annotations, cleaned_text
def is_function_tool_call(
tool_call: OpenAIChatCompletionToolCall,
tools: list[OpenAIResponseInputTool],
) -> bool:
if not tool_call.function:
return False
for t in tools:
if t.type == "function" and t.name == tool_call.function.name:
return True
return False
async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
"""Run guardrails against messages and return violation message if blocked."""
if not messages:
return None
# Look up shields to get their provider_resource_id (actual model ID)
model_ids = []
shields_list = await safety_api.routing_table.list_shields()
for guardrail_id in guardrail_ids:
matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
if matching_shields:
model_id = matching_shields[0].provider_resource_id
model_ids.append(model_id)
else:
raise ValueError(f"No shield found with identifier '{guardrail_id}'")
guardrail_tasks = [safety_api.run_moderation(messages, model=model_id) for model_id in model_ids]
responses = await asyncio.gather(*guardrail_tasks)
for response in responses:
for result in response.results:
if result.flagged:
message = result.user_message or "Content blocked by safety guardrails"
flagged_categories = [cat for cat, flagged in result.categories.items() if flagged]
violation_type = result.metadata.get("violation_type", []) if result.metadata else []
if flagged_categories:
message += f" (flagged for: {', '.join(flagged_categories)})"
if violation_type:
message += f" (violation type: {', '.join(violation_type)})"
return message
def extract_guardrail_ids(guardrails: list | None) -> list[str]:
"""Extract guardrail IDs from guardrails parameter, handling both string IDs and ResponseGuardrailSpec objects."""
if not guardrails:
return []
guardrail_ids = []
for guardrail in guardrails:
if isinstance(guardrail, str):
guardrail_ids.append(guardrail)
elif isinstance(guardrail, ResponseGuardrailSpec):
guardrail_ids.append(guardrail.type)
else:
raise ValueError(f"Unknown guardrail format: {guardrail}, expected str or ResponseGuardrailSpec")
return guardrail_ids

View file

@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.core.telemetry import tracing
from llama_stack.log import get_logger
log = get_logger(name=__name__, category="agents::meta_reference")
class SafetyException(Exception): # noqa: N818
def __init__(self, violation: SafetyViolation):
self.violation = violation
super().__init__(violation.user_message)
class ShieldRunnerMixin:
def __init__(
self,
safety_api: Safety,
input_shields: list[str] | None = None,
output_shields: list[str] | None = None,
):
self.safety_api = safety_api
self.input_shields = input_shields
self.output_shields = output_shields
async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
async def run_shield_with_span(identifier: str):
async with tracing.span(f"run_shield_{identifier}"):
return await self.safety_api.run_shield(
shield_id=identifier,
messages=messages,
params={},
)
responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers])
for identifier, response in zip(identifiers, responses, strict=False):
if not response.violation:
continue
violation = response.violation
if violation.violation_level == ViolationLevel.ERROR:
raise SafetyException(violation)
elif violation.violation_level == ViolationLevel.WARN:
log.warning(f"[Warn]{identifier} raised a warning")

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Models
from llama_stack.core.datatypes import AccessRule, Api
from llama_stack.providers.utils.kvstore import kvstore_impl
from .batches import ReferenceBatchesImpl
from .config import ReferenceBatchesImplConfig
__all__ = ["ReferenceBatchesImpl", "ReferenceBatchesImplConfig"]
async def get_provider_impl(config: ReferenceBatchesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
kvstore = await kvstore_impl(config.kvstore)
inference_api: Inference | None = deps.get(Api.inference)
files_api: Files | None = deps.get(Api.files)
models_api: Models | None = deps.get(Api.models)
if inference_api is None:
raise ValueError("Inference API is required but not provided in dependencies")
if files_api is None:
raise ValueError("Files API is required but not provided in dependencies")
if models_api is None:
raise ValueError("Models API is required but not provided in dependencies")
impl = ReferenceBatchesImpl(config, inference_api, files_api, models_api, kvstore)
await impl.initialize()
return impl

View file

@ -0,0 +1,679 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import hashlib
import itertools
import json
import time
import uuid
from io import BytesIO
from typing import Any, Literal
from openai.types.batch import BatchError, Errors
from pydantic import BaseModel
from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.apis.files import Files, OpenAIFilePurpose
from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
OpenAIDeveloperMessageParam,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIMessageParam,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.models import Models
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
from .config import ReferenceBatchesImplConfig
BATCH_PREFIX = "batch:"
logger = get_logger(__name__)
class AsyncBytesIO:
"""
Async-compatible BytesIO wrapper to allow async file-like operations.
We use this when uploading files to the Files API, as it expects an
async file-like object.
"""
def __init__(self, data: bytes):
self._buffer = BytesIO(data)
async def read(self, n=-1):
return self._buffer.read(n)
async def seek(self, pos, whence=0):
return self._buffer.seek(pos, whence)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._buffer.close()
def __getattr__(self, name):
return getattr(self._buffer, name)
class BatchRequest(BaseModel):
line_num: int
custom_id: str
method: str
url: str
body: dict[str, Any]
def convert_to_openai_message_param(msg: dict[str, Any]) -> OpenAIMessageParam:
"""Convert a message dictionary to OpenAIMessageParam based on role."""
role = msg.get("role")
if role == "user":
return OpenAIUserMessageParam(**msg)
elif role == "system":
return OpenAISystemMessageParam(**msg)
elif role == "assistant":
return OpenAIAssistantMessageParam(**msg)
elif role == "tool":
return OpenAIToolMessageParam(**msg)
elif role == "developer":
return OpenAIDeveloperMessageParam(**msg)
else:
raise ValueError(f"Unknown message role: {role}")
class ReferenceBatchesImpl(Batches):
"""Reference implementation of the Batches API.
This implementation processes batch files by making individual requests
to the inference API and generates output files with results.
"""
def __init__(
self,
config: ReferenceBatchesImplConfig,
inference_api: Inference,
files_api: Files,
models_api: Models,
kvstore: KVStore,
) -> None:
self.config = config
self.kvstore = kvstore
self.inference_api = inference_api
self.files_api = files_api
self.models_api = models_api
self._processing_tasks: dict[str, asyncio.Task] = {}
self._batch_semaphore = asyncio.Semaphore(config.max_concurrent_batches)
self._update_batch_lock = asyncio.Lock()
# this is to allow tests to disable background processing
self.process_batches = True
async def initialize(self) -> None:
# TODO: start background processing of existing tasks
pass
async def shutdown(self) -> None:
"""Shutdown the batches provider."""
if self._processing_tasks:
# don't cancel tasks - just let them stop naturally on shutdown
# cancelling would mark batches as "cancelled" in the database
logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks")
# TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions
async def create_batch(
self,
input_file_id: str,
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
idempotency_key: str | None = None,
) -> BatchObject:
"""
Create a new batch for processing multiple API requests.
This implementation provides optional idempotency: when an idempotency key
(idempotency_key) is provided, a deterministic ID is generated based on the input
parameters. If a batch with the same parameters already exists, it will be
returned instead of creating a duplicate. Without an idempotency key,
each request creates a new batch with a unique ID.
Args:
input_file_id: The ID of an uploaded file containing requests for the batch.
endpoint: The endpoint to be used for all requests in the batch.
completion_window: The time window within which the batch should be processed.
metadata: Optional metadata for the batch.
idempotency_key: Optional idempotency key for enabling idempotent behavior.
Returns:
The created or existing batch object.
"""
# Error handling by levels -
# 0. Input param handling, results in 40x errors before processing, e.g.
# - Wrong completion_window
# - Invalid metadata types
# - Unknown endpoint
# -> no batch created
# 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
# - input_file_id missing
# - invalid json in file
# - missing custom_id, method, url, body
# - invalid model
# - streaming
# -> batch created, validation sends to failed status
# 2. Processing errors, result in error_file_id entries, e.g.
# - Any error returned from inference endpoint
# -> batch created, goes to completed status
# TODO: set expiration time for garbage collection
if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
raise ValueError(
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
)
if completion_window != "24h":
raise ValueError(
f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
)
batch_id = f"batch_{uuid.uuid4().hex[:16]}"
# For idempotent requests, use the idempotency key for the batch ID
# This ensures the same key always maps to the same batch ID,
# allowing us to detect parameter conflicts
if idempotency_key is not None:
hash_input = idempotency_key.encode("utf-8")
hash_digest = hashlib.sha256(hash_input).hexdigest()[:24]
batch_id = f"batch_{hash_digest}"
try:
existing_batch = await self.retrieve_batch(batch_id)
if (
existing_batch.input_file_id != input_file_id
or existing_batch.endpoint != endpoint
or existing_batch.completion_window != completion_window
or existing_batch.metadata != metadata
):
raise ConflictError(
f"Idempotency key '{idempotency_key}' was previously used with different parameters. "
"Either use a new idempotency key or ensure all parameters match the original request."
)
logger.info(f"Returning existing batch with ID: {batch_id}")
return existing_batch
except ResourceNotFoundError:
# Batch doesn't exist, continue with creation
pass
current_time = int(time.time())
batch = BatchObject(
id=batch_id,
object="batch",
endpoint=endpoint,
input_file_id=input_file_id,
completion_window=completion_window,
status="validating",
created_at=current_time,
metadata=metadata,
)
await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
logger.info(f"Created new batch with ID: {batch_id}")
if self.process_batches:
task = asyncio.create_task(self._process_batch(batch_id))
self._processing_tasks[batch_id] = task
return batch
async def cancel_batch(self, batch_id: str) -> BatchObject:
"""Cancel a batch that is in progress."""
batch = await self.retrieve_batch(batch_id)
if batch.status in ["cancelled", "cancelling"]:
return batch
if batch.status in ["completed", "failed", "expired"]:
raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'")
await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time()))
if batch_id in self._processing_tasks:
self._processing_tasks[batch_id].cancel()
# note: task removal and status="cancelled" handled in finally block of _process_batch
return await self.retrieve_batch(batch_id)
async def list_batches(
self,
after: str | None = None,
limit: int = 20,
) -> ListBatchesResponse:
"""
List all batches, eventually only for the current user.
With no notion of user, we return all batches.
"""
batch_values = await self.kvstore.values_in_range("batch:", "batch:\xff")
batches = []
for batch_data in batch_values:
if batch_data:
batches.append(BatchObject.model_validate_json(batch_data))
batches.sort(key=lambda b: b.created_at, reverse=True)
start_idx = 0
if after:
for i, batch in enumerate(batches):
if batch.id == after:
start_idx = i + 1
break
page_batches = batches[start_idx : start_idx + limit]
has_more = (start_idx + limit) < len(batches)
first_id = page_batches[0].id if page_batches else None
last_id = page_batches[-1].id if page_batches else None
return ListBatchesResponse(
data=page_batches,
first_id=first_id,
last_id=last_id,
has_more=has_more,
)
async def retrieve_batch(self, batch_id: str) -> BatchObject:
"""Retrieve information about a specific batch."""
batch_data = await self.kvstore.get(f"batch:{batch_id}")
if not batch_data:
raise ResourceNotFoundError(batch_id, "Batch", "batches.list()")
return BatchObject.model_validate_json(batch_data)
async def _update_batch(self, batch_id: str, **updates) -> None:
"""Update batch fields in kvstore."""
async with self._update_batch_lock:
try:
batch = await self.retrieve_batch(batch_id)
# batch processing is async. once cancelling, only allow "cancelled" status updates
if batch.status == "cancelling" and updates.get("status") != "cancelled":
logger.info(
f"Skipping status update for cancelled batch {batch_id}: attempted {updates.get('status')}"
)
return
if "errors" in updates:
updates["errors"] = updates["errors"].model_dump()
batch_dict = batch.model_dump()
batch_dict.update(updates)
await self.kvstore.set(f"batch:{batch_id}", json.dumps(batch_dict))
except Exception as e:
logger.error(f"Failed to update batch {batch_id}: {e}")
async def _validate_input(self, batch: BatchObject) -> tuple[list[BatchError], list[BatchRequest]]:
"""
Read & validate input, return errors and valid input.
Validation of
- input_file_id existance
- valid json
- custom_id, method, url, body presence and valid
- no streaming
"""
requests: list[BatchRequest] = []
errors: list[BatchError] = []
try:
await self.files_api.openai_retrieve_file(batch.input_file_id)
except Exception:
errors.append(
BatchError(
code="invalid_request",
line=None,
message=f"Cannot find file {batch.input_file_id}.",
param="input_file_id",
)
)
return errors, requests
# TODO(SECURITY): do something about large files
file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id)
file_content = file_content_response.body.decode("utf-8")
for line_num, line in enumerate(file_content.strip().split("\n"), 1):
if line.strip(): # skip empty lines
try:
request = json.loads(line)
if not isinstance(request, dict):
errors.append(
BatchError(
code="invalid_request",
line=line_num,
message="Each line must be a JSON dictionary object",
)
)
continue
valid = True
for param, expected_type, type_string in [
("custom_id", str, "string"),
("method", str, "string"),
("url", str, "string"),
("body", dict, "JSON dictionary object"),
]:
if param not in request:
errors.append(
BatchError(
code="missing_required_parameter",
line=line_num,
message=f"Missing required parameter: {param}",
param=param,
)
)
valid = False
elif not isinstance(request[param], expected_type):
param_name = "URL" if param == "url" else param.capitalize()
errors.append(
BatchError(
code="invalid_request",
line=line_num,
message=f"{param_name} must be a {type_string}",
param=param,
)
)
valid = False
if (url := request.get("url")) and isinstance(url, str) and url != batch.endpoint:
errors.append(
BatchError(
code="invalid_url",
line=line_num,
message="URL provided for this request does not match the batch endpoint",
param="url",
)
)
valid = False
if (body := request.get("body")) and isinstance(body, dict):
if body.get("stream", False):
errors.append(
BatchError(
code="streaming_unsupported",
line=line_num,
message="Streaming is not supported in batch processing",
param="body.stream",
)
)
valid = False
if batch.endpoint == "/v1/chat/completions":
required_params: list[tuple[str, Any, str]] = [
("model", str, "a string"),
# messages is specific to /v1/chat/completions
# we could skip validating messages here and let inference fail. however,
# that would be a very expensive way to find out messages is wrong.
("messages", list, "an array"), # TODO: allow messages to be a string?
]
elif batch.endpoint == "/v1/completions":
required_params = [
("model", str, "a string"),
("prompt", str, "a string"), # TODO: allow prompt to be a list of strings??
]
else: # /v1/embeddings
required_params = [
("model", str, "a string"),
("input", (str, list), "a string or array of strings"),
]
for param, expected_type, type_string in required_params:
if param not in body:
errors.append(
BatchError(
code="invalid_request",
line=line_num,
message=f"{param.capitalize()} parameter is required",
param=f"body.{param}",
)
)
valid = False
elif not isinstance(body[param], expected_type):
errors.append(
BatchError(
code="invalid_request",
line=line_num,
message=f"{param.capitalize()} must be {type_string}",
param=f"body.{param}",
)
)
valid = False
if "model" in body and isinstance(body["model"], str):
try:
await self.models_api.get_model(body["model"])
except Exception:
errors.append(
BatchError(
code="model_not_found",
line=line_num,
message=f"Model '{body['model']}' does not exist or is not supported",
param="body.model",
)
)
valid = False
if valid:
assert isinstance(url, str), "URL must be a string" # for mypy
assert isinstance(body, dict), "Body must be a dictionary" # for mypy
requests.append(
BatchRequest(
line_num=line_num,
url=url,
method=request["method"],
custom_id=request["custom_id"],
body=body,
),
)
except json.JSONDecodeError:
errors.append(
BatchError(
code="invalid_json_line",
line=line_num,
message="This line is not parseable as valid JSON.",
)
)
return errors, requests
async def _process_batch(self, batch_id: str) -> None:
"""Background task to process a batch of requests."""
try:
logger.info(f"Starting batch processing for {batch_id}")
async with self._batch_semaphore: # semaphore to limit concurrency
logger.info(f"Acquired semaphore for batch {batch_id}")
await self._process_batch_impl(batch_id)
except asyncio.CancelledError:
logger.info(f"Batch processing cancelled for {batch_id}")
await self._update_batch(batch_id, status="cancelled", cancelled_at=int(time.time()))
except Exception as e:
logger.error(f"Batch processing failed for {batch_id}: {e}")
await self._update_batch(
batch_id,
status="failed",
failed_at=int(time.time()),
errors=Errors(data=[BatchError(code="internal_error", message=str(e))]),
)
finally:
self._processing_tasks.pop(batch_id, None)
async def _process_batch_impl(self, batch_id: str) -> None:
"""Implementation of batch processing logic."""
errors: list[BatchError] = []
batch = await self.retrieve_batch(batch_id)
errors, requests = await self._validate_input(batch)
if errors:
await self._update_batch(batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors))
logger.info(f"Batch validation failed for {batch_id} with {len(errors)} errors")
return
logger.info(f"Processing {len(requests)} requests for batch {batch_id}")
total_requests = len(requests)
await self._update_batch(
batch_id,
status="in_progress",
request_counts={"total": total_requests, "completed": 0, "failed": 0},
)
error_results = []
success_results = []
completed_count = 0
failed_count = 0
for chunk in itertools.batched(requests, self.config.max_concurrent_requests_per_batch):
# we use a TaskGroup to ensure all process-single-request tasks are canceled when process-batch is cancelled
async with asyncio.TaskGroup() as tg:
chunk_tasks = [tg.create_task(self._process_single_request(batch_id, request)) for request in chunk]
chunk_results = await asyncio.gather(*chunk_tasks, return_exceptions=True)
for result in chunk_results:
if isinstance(result, dict) and result.get("error") is not None: # error response from inference
failed_count += 1
error_results.append(result)
elif isinstance(result, dict) and result.get("response") is not None: # successful inference
completed_count += 1
success_results.append(result)
else: # unexpected result
failed_count += 1
errors.append(BatchError(code="internal_error", message=f"Unexpected result: {result}"))
await self._update_batch(
batch_id,
request_counts={"total": total_requests, "completed": completed_count, "failed": failed_count},
)
if errors:
await self._update_batch(
batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors)
)
return
try:
output_file_id = await self._create_output_file(batch_id, success_results, "success")
await self._update_batch(batch_id, output_file_id=output_file_id)
error_file_id = await self._create_output_file(batch_id, error_results, "error")
await self._update_batch(batch_id, error_file_id=error_file_id)
await self._update_batch(batch_id, status="completed", completed_at=int(time.time()))
logger.info(
f"Batch processing completed for {batch_id}: {completed_count} completed, {failed_count} failed"
)
except Exception as e:
# note: errors is empty at this point, so we don't lose anything by ignoring it
await self._update_batch(
batch_id,
status="failed",
failed_at=int(time.time()),
errors=Errors(data=[BatchError(code="output_failed", message=str(e))]),
)
async def _process_single_request(self, batch_id: str, request: BatchRequest) -> dict:
"""Process a single request from the batch."""
request_id = f"batch_req_{batch_id}_{request.line_num}"
try:
# TODO(SECURITY): review body for security issues
if request.url == "/v1/chat/completions":
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
chat_params = OpenAIChatCompletionRequestWithExtraBody(**request.body)
chat_response = await self.inference_api.openai_chat_completion(chat_params)
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
return {
"id": request_id,
"custom_id": request.custom_id,
"response": {
"status_code": 200,
"request_id": request_id, # TODO: should this be different?
"body": chat_response.model_dump_json(),
},
}
elif request.url == "/v1/completions":
completion_params = OpenAICompletionRequestWithExtraBody(**request.body)
completion_response = await self.inference_api.openai_completion(completion_params)
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(completion_response, "model_dump_json"), (
"Completion response must have model_dump_json method"
)
return {
"id": request_id,
"custom_id": request.custom_id,
"response": {
"status_code": 200,
"request_id": request_id,
"body": completion_response.model_dump_json(),
},
}
else: # /v1/embeddings
embeddings_response = await self.inference_api.openai_embeddings(
OpenAIEmbeddingsRequestWithExtraBody(**request.body)
)
assert hasattr(embeddings_response, "model_dump_json"), (
"Embeddings response must have model_dump_json method"
)
return {
"id": request_id,
"custom_id": request.custom_id,
"response": {
"status_code": 200,
"request_id": request_id, # TODO: should this be different?
"body": embeddings_response.model_dump_json(),
},
}
except Exception as e:
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
return {
"id": request_id,
"custom_id": request.custom_id,
"error": {"type": "request_failed", "message": str(e)},
}
async def _create_output_file(self, batch_id: str, results: list[dict], file_type: str) -> str:
"""
Create an output file with batch results.
This function filters results based on the specified file_type
and uploads the file to the Files API.
"""
output_lines = [json.dumps(result) for result in results]
with AsyncBytesIO("\n".join(output_lines).encode("utf-8")) as file_buffer:
file_buffer.filename = f"{batch_id}_{file_type}.jsonl"
uploaded_file = await self.files_api.openai_upload_file(file=file_buffer, purpose=OpenAIFilePurpose.BATCH)
return uploaded_file.id

View file

@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field
from llama_stack.core.storage.datatypes import KVStoreReference
class ReferenceBatchesImplConfig(BaseModel):
"""Configuration for the Reference Batches implementation."""
kvstore: KVStoreReference = Field(
description="Configuration for the key-value store backend.",
)
max_concurrent_batches: int = Field(
default=1,
description="Maximum number of concurrent batches to process simultaneously.",
ge=1,
)
max_concurrent_requests_per_batch: int = Field(
default=10,
description="Maximum number of concurrent requests to process per batch.",
ge=1,
)
# TODO: add a max requests per second rate limiter
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict:
return {
"kvstore": KVStoreReference(
backend="kv_default",
namespace="batches",
).model_dump(exclude_none=True),
}

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import LocalFSDatasetIOConfig
async def get_provider_impl(
config: LocalFSDatasetIOConfig,
_deps: dict[str, Any],
):
from .datasetio import LocalFSDatasetIOImpl
impl = LocalFSDatasetIOImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel
from llama_stack.core.storage.datatypes import KVStoreReference
class LocalFSDatasetIOConfig(BaseModel):
kvstore: KVStoreReference
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"kvstore": KVStoreReference(
backend="kv_default",
namespace="datasetio::localfs",
).model_dump(exclude_none=True)
}

View file

@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from .config import LocalFSDatasetIOConfig
DATASETS_PREFIX = "localfs_datasets:"
class PandasDataframeDataset:
def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.dataset_def = dataset_def
self.df = None
def __len__(self) -> int:
assert self.df is not None, "Dataset not loaded. Please call .load() first"
return len(self.df)
def __getitem__(self, idx):
assert self.df is not None, "Dataset not loaded. Please call .load() first"
if isinstance(idx, slice):
return self.df.iloc[idx].to_dict(orient="records")
else:
return self.df.iloc[idx].to_dict()
async def load(self) -> None:
if self.df is not None:
return
if self.dataset_def.source.type == "uri":
self.df = await get_dataframe_from_uri(self.dataset_def.source.uri)
elif self.dataset_def.source.type == "rows":
import pandas
self.df = pandas.DataFrame(self.dataset_def.source.rows)
else:
raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
if self.df is None:
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
def __init__(self, config: LocalFSDatasetIOConfig) -> None:
self.config = config
# local registry for keeping track of datasets within the provider
self.dataset_infos = {}
self.kvstore = None
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
# Load existing datasets from kvstore
start_key = DATASETS_PREFIX
end_key = f"{DATASETS_PREFIX}\xff"
stored_datasets = await self.kvstore.values_in_range(start_key, end_key)
for dataset in stored_datasets:
dataset = Dataset.model_validate_json(dataset)
self.dataset_infos[dataset.identifier] = dataset
async def shutdown(self) -> None: ...
async def register_dataset(
self,
dataset_def: Dataset,
) -> None:
# Store in kvstore
key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
await self.kvstore.set(
key=key,
value=dataset_def.model_dump_json(),
)
self.dataset_infos[dataset_def.identifier] = dataset_def
async def unregister_dataset(self, dataset_id: str) -> None:
key = f"{DATASETS_PREFIX}{dataset_id}"
await self.kvstore.delete(key=key)
del self.dataset_infos[dataset_id]
async def iterrows(
self,
dataset_id: str,
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def)
await dataset_impl.load()
records = dataset_impl.df.to_dict("records")
return paginate_records(records, start_index, limit)
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
import pandas
dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def)
await dataset_impl.load()
new_rows_df = pandas.DataFrame(rows)
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.core.datatypes import Api
from .config import MetaReferenceEvalConfig
async def get_provider_impl(
config: MetaReferenceEvalConfig,
deps: dict[Api, Any],
):
from .eval import MetaReferenceEvalImpl
impl = MetaReferenceEvalImpl(
config,
deps[Api.datasetio],
deps[Api.datasets],
deps[Api.scoring],
deps[Api.inference],
deps[Api.agents],
)
await impl.initialize()
return impl

View file

@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel
from llama_stack.core.storage.datatypes import KVStoreReference
class MetaReferenceEvalConfig(BaseModel):
kvstore: KVStoreReference
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"kvstore": KVStoreReference(
backend="kv_default",
namespace="eval",
).model_dump(exclude_none=True)
}

View file

@ -0,0 +1,259 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any
from tqdm import tqdm
from llama_stack.apis.agents import Agents, StepType
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import (
Inference,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
OpenAISystemMessageParam,
OpenAIUserMessageParam,
UserMessage,
)
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
MEMORY_QUERY_TOOL,
)
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
from llama_stack.providers.utils.kvstore import kvstore_impl
from .....apis.common.job_types import Job, JobStatus
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
from .config import MetaReferenceEvalConfig
EVAL_TASKS_PREFIX = "benchmarks:"
class MetaReferenceEvalImpl(
Eval,
BenchmarksProtocolPrivate,
):
def __init__(
self,
config: MetaReferenceEvalConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
scoring_api: Scoring,
inference_api: Inference,
agents_api: Agents,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.scoring_api = scoring_api
self.inference_api = inference_api
self.agents_api = agents_api
# TODO: assume sync job, will need jobs API for async scheduling
self.jobs = {}
self.benchmarks = {}
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
# Load existing benchmarks from kvstore
start_key = EVAL_TASKS_PREFIX
end_key = f"{EVAL_TASKS_PREFIX}\xff"
stored_benchmarks = await self.kvstore.values_in_range(start_key, end_key)
for benchmark in stored_benchmarks:
benchmark = Benchmark.model_validate_json(benchmark)
self.benchmarks[benchmark.identifier] = benchmark
async def shutdown(self) -> None: ...
async def register_benchmark(self, task_def: Benchmark) -> None:
# Store in kvstore
key = f"{EVAL_TASKS_PREFIX}{task_def.identifier}"
await self.kvstore.set(
key=key,
value=task_def.model_dump_json(),
)
self.benchmarks[task_def.identifier] = task_def
async def unregister_benchmark(self, benchmark_id: str) -> None:
if benchmark_id in self.benchmarks:
del self.benchmarks[benchmark_id]
key = f"{EVAL_TASKS_PREFIX}{benchmark_id}"
await self.kvstore.delete(key)
async def run_eval(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
task_def = self.benchmarks[benchmark_id]
dataset_id = task_def.dataset_id
scoring_functions = task_def.scoring_functions
# TODO (xiyan): validate dataset schema
# dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
)
res = await self.evaluate_rows(
benchmark_id=benchmark_id,
input_rows=all_rows.data,
scoring_functions=scoring_functions,
benchmark_config=benchmark_config,
)
# TODO: currently needs to wait for generation before returning
# need job scheduler queue (ray/celery) w/ jobs api
job_id = str(len(self.jobs))
self.jobs[job_id] = res
return Job(job_id=job_id, status=JobStatus.completed)
async def _run_agent_generation(
self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig
) -> list[dict[str, Any]]:
candidate = benchmark_config.eval_candidate
create_response = await self.agents_api.create_agent(candidate.config)
agent_id = create_response.agent_id
generations = []
for i, x in tqdm(enumerate(input_rows)):
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
input_messages = json.loads(x[ColumnName.chat_completion_input.value])
input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
session_id = session_create_response.session_id
turn_request = dict(
agent_id=agent_id,
session_id=session_id,
messages=input_messages,
stream=True,
)
turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)]
final_event = turn_response[-1].event.payload
# check if there's a memory retrieval step and extract the context
memory_rag_context = None
for step in final_event.turn.steps:
if step.step_type == StepType.tool_execution.value:
for tool_response in step.tool_responses:
if tool_response.tool_name == MEMORY_QUERY_TOOL:
memory_rag_context = " ".join(x.text for x in tool_response.content)
agent_generation = {}
agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content
if memory_rag_context:
agent_generation[ColumnName.context.value] = memory_rag_context
generations.append(agent_generation)
return generations
async def _run_model_generation(
self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig
) -> list[dict[str, Any]]:
candidate = benchmark_config.eval_candidate
assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
sampling_params = {"max_tokens": candidate.sampling_params.max_tokens}
generations = []
for x in tqdm(input_rows):
if ColumnName.completion_input.value in x:
if candidate.sampling_params.stop:
sampling_params["stop"] = candidate.sampling_params.stop
input_content = json.loads(x[ColumnName.completion_input.value])
params = OpenAICompletionRequestWithExtraBody(
model=candidate.model,
prompt=input_content,
**sampling_params,
)
response = await self.inference_api.openai_completion(params)
generations.append({ColumnName.generated_answer.value: response.choices[0].text})
elif ColumnName.chat_completion_input.value in x:
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
input_messages = [
OpenAIUserMessageParam(**x) for x in chat_completion_input_json if x["role"] == "user"
]
messages = []
if candidate.system_message:
messages.append(candidate.system_message)
messages += [OpenAISystemMessageParam(**x) for x in chat_completion_input_json if x["role"] == "system"]
messages += input_messages
params = OpenAIChatCompletionRequestWithExtraBody(
model=candidate.model,
messages=messages,
**sampling_params,
)
response = await self.inference_api.openai_chat_completion(params)
generations.append({ColumnName.generated_answer.value: response.choices[0].message.content})
else:
raise ValueError("Invalid input row")
return generations
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: list[dict[str, Any]],
scoring_functions: list[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
candidate = benchmark_config.eval_candidate
if candidate.type == "agent":
generations = await self._run_agent_generation(input_rows, benchmark_config)
elif candidate.type == "model":
generations = await self._run_model_generation(input_rows, benchmark_config)
else:
raise ValueError(f"Invalid candidate type: {candidate.type}")
# scoring with generated_answer
score_input_rows = [
input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
]
if benchmark_config.scoring_params is not None:
scoring_functions_dict = {
scoring_fn_id: benchmark_config.scoring_params.get(scoring_fn_id, None)
for scoring_fn_id in scoring_functions
}
else:
scoring_functions_dict = dict.fromkeys(scoring_functions)
score_response = await self.scoring_api.score(
input_rows=score_input_rows, scoring_functions=scoring_functions_dict
)
return EvaluateResponse(generations=generations, scores=score_response.results)
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
if job_id in self.jobs:
return Job(job_id=job_id, status=JobStatus.completed)
raise ValueError(f"Job {job_id} not found")
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
job = await self.job_status(benchmark_id, job_id)
status = job.status
if not status or status != JobStatus.completed:
raise ValueError(f"Job is not completed, Status: {status.value}")
return self.jobs[job_id]

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.core.datatypes import AccessRule, Api
from .config import LocalfsFilesImplConfig
from .files import LocalfsFilesImpl
__all__ = ["LocalfsFilesImpl", "LocalfsFilesImplConfig"]
async def get_provider_impl(config: LocalfsFilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
impl = LocalfsFilesImpl(config, policy)
await impl.initialize()
return impl

View file

@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.core.storage.datatypes import SqlStoreReference
class LocalfsFilesImplConfig(BaseModel):
storage_dir: str = Field(
description="Directory to store uploaded files",
)
metadata_store: SqlStoreReference = Field(
description="SQL store configuration for file metadata",
)
ttl_secs: int = 365 * 24 * 60 * 60 # 1 year
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"storage_dir": "${env.FILES_STORAGE_DIR:=" + __distro_dir__ + "/files}",
"metadata_store": SqlStoreReference(
backend="sql_default",
table_name="files_metadata",
).model_dump(exclude_none=True),
}

View file

@ -0,0 +1,219 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
import uuid
from pathlib import Path
from typing import Annotated
from fastapi import Depends, File, Form, Response, UploadFile
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import (
ExpiresAfter,
Files,
ListOpenAIFileResponse,
OpenAIFileDeleteResponse,
OpenAIFileObject,
OpenAIFilePurpose,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.id_generation import generate_object_id
from llama_stack.log import get_logger
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
from .config import LocalfsFilesImplConfig
logger = get_logger(name=__name__, category="files")
class LocalfsFilesImpl(Files):
def __init__(self, config: LocalfsFilesImplConfig, policy: list[AccessRule]) -> None:
self.config = config
self.policy = policy
self.sql_store: AuthorizedSqlStore | None = None
async def initialize(self) -> None:
"""Initialize the files provider by setting up storage directory and metadata database."""
# Create storage directory if it doesn't exist
storage_path = Path(self.config.storage_dir)
storage_path.mkdir(parents=True, exist_ok=True)
# Initialize SQL store for metadata
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
await self.sql_store.create_table(
"openai_files",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"filename": ColumnType.STRING,
"purpose": ColumnType.STRING,
"bytes": ColumnType.INTEGER,
"created_at": ColumnType.INTEGER,
"expires_at": ColumnType.INTEGER,
"file_path": ColumnType.STRING, # Path to actual file on disk
},
)
async def shutdown(self) -> None:
pass
def _generate_file_id(self) -> str:
"""Generate a unique file ID for OpenAI API."""
return generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
def _get_file_path(self, file_id: str) -> Path:
"""Get the filesystem path for a file ID."""
return Path(self.config.storage_dir) / file_id
async def _lookup_file_id(self, file_id: str) -> tuple[OpenAIFileObject, Path]:
"""Look up a OpenAIFileObject and filesystem path from its ID."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
file_path = Path(row.pop("file_path"))
return OpenAIFileObject(**row), file_path
# OpenAI Files API Implementation
async def openai_upload_file(
self,
file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
) -> OpenAIFileObject:
"""Upload a file that can be used across various endpoints."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
if expires_after is not None:
logger.warning(
f"File expiration is not supported by this provider, ignoring expires_after: {expires_after}"
)
file_id = self._generate_file_id()
file_path = self._get_file_path(file_id)
content = await file.read()
file_size = len(content)
with open(file_path, "wb") as f:
f.write(content)
created_at = int(time.time())
expires_at = created_at + self.config.ttl_secs
await self.sql_store.insert(
"openai_files",
{
"id": file_id,
"filename": file.filename or "uploaded_file",
"purpose": purpose.value,
"bytes": file_size,
"created_at": created_at,
"expires_at": expires_at,
"file_path": file_path.as_posix(),
},
)
return OpenAIFileObject(
id=file_id,
filename=file.filename or "uploaded_file",
purpose=purpose,
bytes=file_size,
created_at=created_at,
expires_at=expires_at,
)
async def openai_list_files(
self,
after: str | None = None,
limit: int | None = 10000,
order: Order | None = Order.desc,
purpose: OpenAIFilePurpose | None = None,
) -> ListOpenAIFileResponse:
"""Returns a list of files that belong to the user's organization."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
if not order:
order = Order.desc
where_conditions = {}
if purpose:
where_conditions["purpose"] = purpose.value
paginated_result = await self.sql_store.fetch_all(
table="openai_files",
where=where_conditions if where_conditions else None,
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,
limit=limit,
)
files = [
OpenAIFileObject(
id=row["id"],
filename=row["filename"],
purpose=OpenAIFilePurpose(row["purpose"]),
bytes=row["bytes"],
created_at=row["created_at"],
expires_at=row["expires_at"],
)
for row in paginated_result.data
]
return ListOpenAIFileResponse(
data=files,
has_more=paginated_result.has_more,
first_id=files[0].id if files else "",
last_id=files[-1].id if files else "",
)
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
"""Returns information about a specific file."""
file_obj, _ = await self._lookup_file_id(file_id)
return file_obj
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
"""Delete a file."""
# Delete physical file
_, file_path = await self._lookup_file_id(file_id)
if file_path.exists():
file_path.unlink()
# Delete metadata from database
assert self.sql_store is not None, "Files provider not initialized"
await self.sql_store.delete("openai_files", where={"id": file_id})
return OpenAIFileDeleteResponse(
id=file_id,
deleted=True,
)
async def openai_retrieve_file_content(self, file_id: str) -> Response:
"""Returns the contents of the specified file."""
# Read file content
file_obj, file_path = await self._lookup_file_id(file_id)
if not file_path.exists():
logger.warning(f"File '{file_id}'s underlying '{file_path}' is missing, deleting metadata.")
await self.openai_delete_file(file_id)
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
# Return as binary response with appropriate content type
return Response(
content=file_path.read_bytes(),
media_type="application/octet-stream",
headers={"Content-Disposition": f'attachment; filename="{file_obj.filename}"'},
)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import MetaReferenceInferenceConfig
async def get_provider_impl(
config: MetaReferenceInferenceConfig,
_deps: dict[str, Any],
):
from .inference import MetaReferenceInferenceImpl
impl = MetaReferenceInferenceImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.core.utils.model_utils import model_local_dir
def model_checkpoint_dir(model_id) -> str:
checkpoint_dir = Path(model_local_dir(model_id))
paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
if not any(p.exists() for p in paths):
checkpoint_dir = checkpoint_dir / "original"
assert checkpoint_dir.exists(), (
f"Could not find checkpoints in: {model_local_dir(model_id)}. "
f"If you try to use the native llama model, please download the model using `llama-model download --source meta --model-id {model_id}` (see https://github.com/meta-llama/llama-models). "
f"Otherwise, please save your model checkpoint under {model_local_dir(model_id)}"
)
return str(checkpoint_dir)

View file

@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, field_validator
from llama_stack.apis.inference import QuantizationConfig
from llama_stack.providers.utils.inference import supported_inference_models
class MetaReferenceInferenceConfig(BaseModel):
# this is a placeholder to indicate inference model id
# the actual inference model id is dtermined by the moddel id in the request
# Note: you need to register the model before using it for inference
# models in the resouce list in the run.yaml config will be registered automatically
model: str | None = None
torch_seed: int | None = None
max_seq_len: int = 4096
max_batch_size: int = 1
model_parallel_size: int | None = None
# when this is False, we assume that the distributed process group is setup by someone
# outside of this code (e.g., when run inside `torchrun`). that is useful for clients
# (including our testing code) who might be using llama-stack as a library.
create_distributed_process_group: bool = True
# By default, the implementation will look at ~/.llama/checkpoints/<model> but you
# can override by specifying the directory explicitly
checkpoint_dir: str | None = None
quantization: QuantizationConfig | None = None
@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = supported_inference_models()
descriptors = [m.descriptor() for m in permitted_models]
repos = [m.huggingface_repo for m in permitted_models if m.huggingface_repo is not None]
if model not in (descriptors + repos):
model_list = "\n\t".join(repos)
raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
return model
@classmethod
def sample_run_config(
cls,
model: str = "Llama3.2-3B-Instruct",
checkpoint_dir: str = "${env.CHECKPOINT_DIR:=null}",
quantization_type: str = "${env.QUANTIZATION_TYPE:=bf16}",
model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:=0}",
max_batch_size: str = "${env.MAX_BATCH_SIZE:=1}",
max_seq_len: str = "${env.MAX_SEQ_LEN:=4096}",
**kwargs,
) -> dict[str, Any]:
return {
"model": model,
"checkpoint_dir": checkpoint_dir,
"quantization": {
"type": quantization_type,
},
"model_parallel_size": model_parallel_size,
"max_batch_size": max_batch_size,
"max_seq_len": max_seq_len,
}

View file

@ -0,0 +1,211 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import math
from collections.abc import Generator
from typing import Optional
import torch
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from llama_stack.apis.inference import (
GreedySamplingStrategy,
JsonSchemaResponseFormat,
ResponseFormat,
SamplingParams,
TopPSamplingStrategy,
)
from llama_stack.models.llama.datatypes import QuantizationMode
from llama_stack.models.llama.llama3.generation import Llama3
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.generation import Llama4
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_types import Model, ModelFamily
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
get_default_tool_prompt_format,
)
from .common import model_checkpoint_dir
from .config import MetaReferenceInferenceConfig
from .inference import resolve_model
Tokenizer = Llama4Tokenizer | Llama3Tokenizer
class LogitsProcessor:
def __init__(self, token_enforcer: TokenEnforcer):
self.token_enforcer = token_enforcer
self.mask: torch.Tensor | None = None
def __call__(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
token_sequence = tokens[0, :].tolist()
allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
if self.mask is not None:
self.mask.fill_(-math.inf)
else:
self.mask = torch.full_like(scores, -math.inf)
self.mask[:, :, allowed_tokens] = 0
scores = scores + self.mask
return scores
def get_logits_processor(
tokenizer: Tokenizer,
vocab_size: int,
response_format: ResponseFormat | None,
) -> Optional["LogitsProcessor"]:
if response_format is None:
return None
if not isinstance(response_format, JsonSchemaResponseFormat):
raise ValueError(f"Unsupported response format type {response_format.type}")
parser = JsonSchemaParser(response_format.json_schema)
data = TokenEnforcerTokenizerData(
_build_regular_tokens_list(tokenizer, vocab_size),
tokenizer.decode,
tokenizer.stop_tokens,
)
token_enforcer = TokenEnforcer(data, parser)
return LogitsProcessor(token_enforcer)
def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> list[tuple[int, str, bool]]:
token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
regular_tokens = []
special_token_ids = set(tokenizer.special_tokens.values())
for token_idx in range(vocab_size):
if token_idx in special_token_ids:
continue
# We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
decoded_regular = tokenizer.decode([token_idx])
is_word_start_token = len(decoded_after_0) > len(decoded_regular)
regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
return regular_tokens
def _infer_sampling_params(sampling_params: SamplingParams):
if isinstance(sampling_params.strategy, GreedySamplingStrategy):
temperature = 0.0
top_p = 1.0
elif isinstance(sampling_params.strategy, TopPSamplingStrategy):
temperature = sampling_params.strategy.temperature or 1.0
top_p = sampling_params.strategy.top_p or 1.0
else:
raise ValueError(f"Unsupported sampling strategy {sampling_params.strategy}")
return temperature, top_p
def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
tool_config = request.tool_config
if tool_config is not None and tool_config.tool_prompt_format is not None:
return tool_config.tool_prompt_format
else:
return get_default_tool_prompt_format(request.model)
class LlamaGenerator:
def __init__(
self,
config: MetaReferenceInferenceConfig,
model_id: str,
llama_model: Model,
):
if config.checkpoint_dir and config.checkpoint_dir != "null":
ckpt_dir = config.checkpoint_dir
else:
resolved_model = resolve_model(model_id)
if resolved_model is None:
# if the model is not a native llama model, get the default checkpoint_dir based on model id
ckpt_dir = model_checkpoint_dir(model_id)
else:
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
if config.quantization:
if config.quantization.type == "fp8_mixed":
quantization_mode = QuantizationMode.fp8_mixed
elif config.quantization.type == "int4_mixed":
quantization_mode = QuantizationMode.int4_mixed
elif config.quantization.type == "bf16":
quantization_mode = None
else:
raise ValueError(f"Unsupported quantization mode {config.quantization}")
else:
quantization_mode = None
cls = Llama4 if llama_model.model_family == ModelFamily.llama4 else Llama3
self.inner_generator = cls.build(
ckpt_dir=ckpt_dir,
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
world_size=config.model_parallel_size or llama_model.pth_file_count,
quantization_mode=quantization_mode,
)
self.tokenizer = self.inner_generator.tokenizer
self.args = self.inner_generator.args
self.formatter = self.inner_generator.formatter
def completion(
self,
request_batch: list[CompletionRequestWithRawContent],
) -> Generator:
first_request = request_batch[0]
sampling_params = first_request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
yield from self.inner_generator.generate(
llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(first_request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
first_request.response_format,
),
)
def chat_completion(
self,
request_batch: list[ChatCompletionRequestWithRawContent],
) -> Generator:
first_request = request_batch[0]
sampling_params = first_request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
yield from self.inner_generator.generate(
llm_inputs=[
self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
for request in request_batch
],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(first_request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
first_request.response_format,
),
)

View file

@ -0,0 +1,158 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from collections.abc import AsyncIterator
from llama_stack.apis.inference import (
InferenceProvider,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from .config import MetaReferenceInferenceConfig
from .generators import LlamaGenerator
from .model_parallel import LlamaModelParallelGenerator
log = get_logger(__name__, category="inference")
# there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1)
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
return LlamaGenerator(config, model_id, llama_model)
class MetaReferenceInferenceImpl(
SentenceTransformerEmbeddingMixin,
InferenceProvider,
ModelsProtocolPrivate,
):
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
self.config = config
self.model_id = None
self.llama_model = None
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
if self.config.create_distributed_process_group:
self.generator.stop()
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by meta reference provider")
async def should_refresh_models(self) -> bool:
return False
async def list_models(self) -> list[Model] | None:
return None
async def unregister_model(self, model_id: str) -> None:
pass
async def register_model(self, model: Model) -> Model:
llama_model = (
resolve_model(model.metadata["llama_model"])
if "llama_model" in model.metadata
else resolve_model(model.identifier)
)
if llama_model is None:
raise ValueError(
"Please make sure your llama_model in model metadata or model identifier is in Llama SKU list"
)
self.model_registry_helper = ModelRegistryHelper(
[
build_hf_repo_model_entry(
llama_model.descriptor(),
llama_model.core_model_id.value,
)
],
)
model = await self.model_registry_helper.register_model(model)
if model.model_type == ModelType.embedding:
self._load_sentence_transformer_model(model.provider_resource_id)
# TODO: what is this?! you can't really specify skipping via model metadata
# kill this madness
if "skip_load" in model.metadata and model.metadata["skip_load"]:
return model
await self.load_model(model.identifier, llama_model)
return model
async def load_model(self, model_id, llama_model) -> None:
log.info(f"Loading model `{model_id}`")
builder_params = [self.config, model_id, llama_model]
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(
model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
builder_fn=llama_builder_fn,
builder_params=builder_params,
formatter=(
Llama4ChatFormat(Llama4Tokenizer.get_instance())
if llama_model.model_family == ModelFamily.llama4
else Llama3ChatFormat(Llama3Tokenizer.get_instance())
),
)
self.generator.start()
else:
self.generator = llama_builder_fn(*builder_params)
self.model_id = model_id
self.llama_model = llama_model
log.info("Warming up...")
await self.openai_chat_completion(
model=model_id,
messages=[{"role": "user", "content": "Hi how are you?"}],
max_tokens=20,
)
log.info("Warmed up!")
def check_model(self, request) -> None:
if self.model_id is None or self.llama_model is None:
raise RuntimeError(
"No avaible model yet, please register your requested model or add your model in the resouces first"
)
elif request.model != self.model_id:
raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
async def openai_chat_completion(
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")

View file

@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Callable, Generator
from copy import deepcopy
from functools import partial
from typing import Any
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
from .parallel_utils import ModelParallelProcessGroup
class ModelRunner:
def __init__(self, llama):
self.llama = llama
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
def __call__(self, task: Any):
if task[0] == "chat_completion":
return self.llama.chat_completion(task[1])
else:
raise ValueError(f"Unexpected task type {task[0]}")
def init_model_cb(
builder_fn: Callable,
params: list[Any],
):
llama = builder_fn(*params)
return ModelRunner(llama)
class LlamaModelParallelGenerator:
"""
This abstraction exists so
- we can run model parallel code without needing to run the CLIs via torchrun
- this also enables use model parallel code within a notebook context.
A Context Manager is used to ensure that the model parallel process is started and stopped
correctly. This does make the ergonomics a little awkward, because it isn't immediately
clear at the callsite why we need to use a context manager.
"""
def __init__(
self,
model_parallel_size: int,
builder_fn: Callable,
builder_params: list[Any],
formatter: Llama3ChatFormat | Llama4ChatFormat,
):
self.model_parallel_size = model_parallel_size
self.builder_fn = builder_fn
self.builder_params = builder_params
self.formatter = formatter
def start(self):
self.__enter__()
def stop(self):
self.__exit__(None, None, None)
def __enter__(self):
self.group = ModelParallelProcessGroup(
self.model_parallel_size,
init_model_cb=partial(init_model_cb, self.builder_fn, self.builder_params),
)
self.group.start()
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
self.group.stop()
def completion(
self,
request_batch: list[CompletionRequestWithRawContent],
) -> Generator:
req_obj = deepcopy(request_batch)
gen = self.group.run_inference(("completion", req_obj))
yield from gen
def chat_completion(
self,
request_batch: list[ChatCompletionRequestWithRawContent],
) -> Generator:
req_obj = deepcopy(request_batch)
gen = self.group.run_inference(("chat_completion", req_obj))
yield from gen

View file

@ -0,0 +1,363 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import copy
import json
import multiprocessing
import os
import tempfile
import time
import uuid
from collections.abc import Callable, Generator
from enum import Enum
from typing import Annotated, Literal
import torch
import zmq
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_group,
get_model_parallel_rank,
get_model_parallel_src_rank,
)
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import GenerationResult
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
log = get_logger(name=__name__, category="inference")
class ProcessingMessageName(str, Enum):
ready_request = "ready_request"
ready_response = "ready_response"
end_sentinel = "end_sentinel"
cancel_sentinel = "cancel_sentinel"
task_request = "task_request"
task_response = "task_response"
exception_response = "exception_response"
class ReadyRequest(BaseModel):
type: Literal[ProcessingMessageName.ready_request] = ProcessingMessageName.ready_request
class ReadyResponse(BaseModel):
type: Literal[ProcessingMessageName.ready_response] = ProcessingMessageName.ready_response
class EndSentinel(BaseModel):
type: Literal[ProcessingMessageName.end_sentinel] = ProcessingMessageName.end_sentinel
class CancelSentinel(BaseModel):
type: Literal[ProcessingMessageName.cancel_sentinel] = ProcessingMessageName.cancel_sentinel
class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
task: tuple[
str,
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
]
class TaskResponse(BaseModel):
type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
result: list[GenerationResult]
class ExceptionResponse(BaseModel):
type: Literal[ProcessingMessageName.exception_response] = ProcessingMessageName.exception_response
error: str
ProcessingMessage = (
ReadyRequest | ReadyResponse | EndSentinel | CancelSentinel | TaskRequest | TaskResponse | ExceptionResponse
)
class ProcessingMessageWrapper(BaseModel):
payload: Annotated[
ProcessingMessage,
Field(discriminator="type"),
]
def mp_rank_0() -> bool:
return bool(get_model_parallel_rank() == 0)
def encode_msg(msg: ProcessingMessage) -> bytes:
return ProcessingMessageWrapper(payload=msg).model_dump_json().encode("utf-8")
def retrieve_requests(reply_socket_url: str):
if mp_rank_0():
context = zmq.Context()
reply_socket = context.socket(zmq.ROUTER)
reply_socket.connect(reply_socket_url)
while True:
client_id, obj = maybe_get_work(reply_socket)
if obj is None:
time.sleep(0.01)
continue
ready_response = ReadyResponse()
reply_socket.send_multipart([client_id, encode_msg(ready_response)])
break
def send_obj(obj: ProcessingMessage):
reply_socket.send_multipart([client_id, encode_msg(obj)])
while True:
tasks: list[ProcessingMessage | None] = [None]
if mp_rank_0():
client_id, maybe_task_json = maybe_get_work(reply_socket)
if maybe_task_json is not None:
task = maybe_parse_message(maybe_task_json)
# there is still an unknown unclean GeneratorExit happening resulting in a
# cancel sentinel getting queued _after_ we have finished sending everything :/
# kind of a hack this is :/
if task is not None and not isinstance(task, CancelSentinel):
tasks = [task]
torch.distributed.broadcast_object_list(
tasks,
src=get_model_parallel_src_rank(),
group=get_model_parallel_group(),
)
task = tasks[0]
if task is None:
time.sleep(0.1)
else:
try:
out = yield task
if out is None:
break
for obj in out:
updates: list[ProcessingMessage | None] = [None]
if mp_rank_0():
_, update_json = maybe_get_work(reply_socket)
update = maybe_parse_message(update_json)
if isinstance(update, CancelSentinel):
updates = [update]
else:
# only send the update if it's not cancelled otherwise the object sits in the socket
# and gets pulled in the next request lol
send_obj(TaskResponse(result=obj))
torch.distributed.broadcast_object_list(
updates,
src=get_model_parallel_src_rank(),
group=get_model_parallel_group(),
)
if isinstance(updates[0], CancelSentinel):
log.info("quitting generation loop because request was cancelled")
break
if mp_rank_0():
send_obj(EndSentinel())
except Exception as e:
log.exception("exception in generation loop")
if mp_rank_0():
send_obj(ExceptionResponse(error=str(e)))
if mp_rank_0():
send_obj(EndSentinel())
def maybe_get_work(sock: zmq.Socket):
message = None
client_id = None
try:
client_id, obj = sock.recv_multipart(zmq.NOBLOCK)
message = obj.decode("utf-8")
except zmq.ZMQError as e:
if e.errno != zmq.EAGAIN:
raise e
return client_id, message
def maybe_parse_message(maybe_json: str | None) -> ProcessingMessage | None:
if maybe_json is None:
return None
try:
return parse_message(maybe_json)
except json.JSONDecodeError:
return None
except ValueError:
return None
def parse_message(json_str: str) -> ProcessingMessage:
data = json.loads(json_str)
return copy.deepcopy(ProcessingMessageWrapper(**data).payload)
def worker_process_entrypoint(
reply_socket_url: str,
init_model_cb: Callable,
) -> None:
model = init_model_cb()
torch.distributed.barrier()
time.sleep(1)
# run the requests co-routine which retrieves requests from the socket
# and sends responses (we provide) back to the caller
req_gen = retrieve_requests(reply_socket_url)
result = None
while True:
try:
task = req_gen.send(result)
if isinstance(task, EndSentinel):
break
assert isinstance(task, TaskRequest), task
result = model(task.task)
except StopIteration:
break
log.info("[debug] worker process done")
def launch_dist_group(
reply_socket_url: str,
model_parallel_size: int,
init_model_cb: Callable,
**kwargs,
) -> None:
with tempfile.TemporaryDirectory() as tmpdir:
# TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
launch_config = LaunchConfig(
max_nodes=1,
min_nodes=1,
nproc_per_node=model_parallel_size,
start_method="fork",
rdzv_backend="c10d",
rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
rdzv_configs={"store_type": "file", "timeout": 90},
max_restarts=0,
monitor_interval=1,
run_id=str(uuid.uuid4()),
)
elastic_launch(launch_config, entrypoint=worker_process_entrypoint)(
reply_socket_url,
init_model_cb,
)
def start_model_parallel_process(
model_parallel_size: int,
init_model_cb: Callable,
**kwargs,
):
context = zmq.Context()
request_socket = context.socket(zmq.DEALER)
# Binding the request socket to a random port
request_socket.bind("tcp://127.0.0.1:0")
main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT)
ctx = multiprocessing.get_context("spawn")
process = ctx.Process(
target=launch_dist_group,
args=(
main_process_url,
model_parallel_size,
init_model_cb,
),
kwargs=kwargs,
)
process.start()
# wait until the model is loaded; rank 0 will send a message to indicate it's ready
request_socket.send(encode_msg(ReadyRequest()))
_response = request_socket.recv()
log.info("Loaded model...")
return request_socket, process
class ModelParallelProcessGroup:
def __init__(
self,
model_parallel_size: int,
init_model_cb: Callable,
**kwargs,
):
self.model_parallel_size = model_parallel_size
self.init_model_cb = init_model_cb
self.started = False
self.running = False
def start(self):
assert not self.started, "process group already started"
self.request_socket, self.process = start_model_parallel_process(
self.model_parallel_size,
self.init_model_cb,
)
self.started = True
def stop(self):
assert self.started, "process group not started"
if self.process.is_alive():
self.request_socket.send(encode_msg(EndSentinel()), zmq.NOBLOCK)
self.process.join()
self.started = False
def run_inference(
self,
req: tuple[
str,
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
],
) -> Generator:
assert not self.running, "inference already running"
self.running = True
try:
self.request_socket.send(encode_msg(TaskRequest(task=req)))
while True:
obj_json = self.request_socket.recv()
obj = parse_message(obj_json)
if isinstance(obj, EndSentinel):
break
if isinstance(obj, ExceptionResponse):
log.error(f"[debug] got exception {obj.error}")
raise Exception(obj.error)
if isinstance(obj, TaskResponse):
yield obj.result
except GeneratorExit:
self.request_socket.send(encode_msg(CancelSentinel()))
while True:
obj_json = self.request_socket.send()
obj = parse_message(obj_json)
if isinstance(obj, EndSentinel):
break
finally:
self.running = False

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.providers.inline.inference.sentence_transformers.config import (
SentenceTransformersInferenceConfig,
)
async def get_provider_impl(
config: SentenceTransformersInferenceConfig,
_deps: dict[str, Any],
):
from .sentence_transformers import SentenceTransformersInferenceImpl
impl = SentenceTransformersInferenceImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel
class SentenceTransformersInferenceConfig(BaseModel):
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {}

View file

@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from llama_stack.apis.inference import (
InferenceProvider,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
)
from .config import SentenceTransformersInferenceConfig
log = get_logger(name=__name__, category="inference")
class SentenceTransformersInferenceImpl(
OpenAIChatCompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
InferenceProvider,
ModelsProtocolPrivate,
):
__provider_id__: str
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def should_refresh_models(self) -> bool:
return False
async def list_models(self) -> list[Model] | None:
return [
Model(
identifier="nomic-ai/nomic-embed-text-v1.5",
provider_resource_id="nomic-ai/nomic-embed-text-v1.5",
provider_id=self.__provider_id__,
metadata={
"embedding_dimension": 768,
},
model_type=ModelType.embedding,
),
]
async def register_model(self, model: Model) -> Model:
return model
async def unregister_model(self, model_id: str) -> None:
pass
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")
async def openai_chat_completion(
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider")

View file

@ -0,0 +1,550 @@
// !$*UTF8*$!
{
archiveVersion = 1;
classes = {
};
objectVersion = 56;
objects = {
/* Begin PBXBuildFile section */
5CADC71A2CA471CC007662D2 /* LlamaStackClient in Frameworks */ = {isa = PBXBuildFile; productRef = 5CADC7192CA471CC007662D2 /* LlamaStackClient */; };
5CAF3DD82CA485740029CD2B /* LlamaStackClient in Frameworks */ = {isa = PBXBuildFile; productRef = 5CAF3DD72CA485740029CD2B /* LlamaStackClient */; };
5CCBC60C2CA1F04A00E958D0 /* LocalInference.h in Headers */ = {isa = PBXBuildFile; fileRef = 5CCBC60B2CA1F04A00E958D0 /* LocalInference.h */; settings = {ATTRIBUTES = (Public, ); }; };
5CCBC6752CA1F45800E958D0 /* executorch_debug in Frameworks */ = {isa = PBXBuildFile; productRef = 5CCBC6742CA1F45800E958D0 /* executorch_debug */; };
5CCBC6862CA1F64A00E958D0 /* LLaMARunner.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */; platformFilter = ios; };
5CCBC6872CA1F64A00E958D0 /* LLaMARunner.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = 5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */; platformFilter = ios; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
5CCBC68D2CA1F7A100E958D0 /* PromptTemplate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC6892CA1F7A000E958D0 /* PromptTemplate.swift */; };
5CCBC68E2CA1F7A100E958D0 /* LocalInference.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC68A2CA1F7A000E958D0 /* LocalInference.swift */; };
5CCBC68F2CA1F7A100E958D0 /* Parsing.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC68B2CA1F7A000E958D0 /* Parsing.swift */; };
5CCBC6902CA1F7A100E958D0 /* SystemPrompts.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC68C2CA1F7A100E958D0 /* SystemPrompts.swift */; };
5CCBC6932CA1F7D000E958D0 /* Stencil in Frameworks */ = {isa = PBXBuildFile; productRef = 5CCBC6922CA1F7D000E958D0 /* Stencil */; };
/* End PBXBuildFile section */
/* Begin PBXContainerItemProxy section */
5CCBC67D2CA1F63F00E958D0 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
proxyType = 2;
remoteGlobalIDString = 036CAF9D2BB1444500D6C2D5;
remoteInfo = LLaMA;
};
5CCBC67F2CA1F63F00E958D0 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
proxyType = 2;
remoteGlobalIDString = 03729ED52BB1F8DE00152F2E;
remoteInfo = LLaMARunner;
};
5CCBC69E2CA2036B00E958D0 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
proxyType = 2;
remoteGlobalIDString = 5CCBC6982CA2036A00E958D0;
remoteInfo = LLaMAPerfBenchmark;
};
5CCBC6A02CA2036B00E958D0 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
proxyType = 2;
remoteGlobalIDString = 5CCBC6992CA2036A00E958D0;
remoteInfo = LLaMAPerfBenchmarkTests;
};
/* End PBXContainerItemProxy section */
/* Begin PBXCopyFilesBuildPhase section */
5CCBC6882CA1F64A00E958D0 /* Embed Frameworks */ = {
isa = PBXCopyFilesBuildPhase;
buildActionMask = 2147483647;
dstPath = "";
dstSubfolderSpec = 10;
files = (
5CCBC6872CA1F64A00E958D0 /* LLaMARunner.framework in Embed Frameworks */,
);
name = "Embed Frameworks";
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXCopyFilesBuildPhase section */
/* Begin PBXFileReference section */
5CCBC6082CA1F04A00E958D0 /* LocalInferenceImpl.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = LocalInferenceImpl.framework; sourceTree = BUILT_PRODUCTS_DIR; };
5CCBC60B2CA1F04A00E958D0 /* LocalInference.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = LocalInference.h; sourceTree = "<group>"; };
5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */ = {isa = PBXFileReference; lastKnownFileType = "wrapper.pb-project"; name = LLaMA.xcodeproj; path = "executorch/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj"; sourceTree = "<group>"; };
5CCBC6892CA1F7A000E958D0 /* PromptTemplate.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PromptTemplate.swift; sourceTree = "<group>"; };
5CCBC68A2CA1F7A000E958D0 /* LocalInference.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = LocalInference.swift; sourceTree = "<group>"; };
5CCBC68B2CA1F7A000E958D0 /* Parsing.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Parsing.swift; sourceTree = "<group>"; };
5CCBC68C2CA1F7A100E958D0 /* SystemPrompts.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SystemPrompts.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
5CCBC6052CA1F04A00E958D0 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
5CADC71A2CA471CC007662D2 /* LlamaStackClient in Frameworks */,
5CAF3DD82CA485740029CD2B /* LlamaStackClient in Frameworks */,
5CCBC6932CA1F7D000E958D0 /* Stencil in Frameworks */,
5CCBC6862CA1F64A00E958D0 /* LLaMARunner.framework in Frameworks */,
5CCBC6752CA1F45800E958D0 /* executorch_debug in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXFrameworksBuildPhase section */
/* Begin PBXGroup section */
5CCBC5FE2CA1F04A00E958D0 = {
isa = PBXGroup;
children = (
5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */,
5CCBC60A2CA1F04A00E958D0 /* LocalInferenceImpl */,
5CCBC6092CA1F04A00E958D0 /* Products */,
5CCBC6852CA1F64A00E958D0 /* Frameworks */,
);
sourceTree = "<group>";
};
5CCBC6092CA1F04A00E958D0 /* Products */ = {
isa = PBXGroup;
children = (
5CCBC6082CA1F04A00E958D0 /* LocalInferenceImpl.framework */,
);
name = Products;
sourceTree = "<group>";
};
5CCBC60A2CA1F04A00E958D0 /* LocalInferenceImpl */ = {
isa = PBXGroup;
children = (
5CCBC68A2CA1F7A000E958D0 /* LocalInference.swift */,
5CCBC68B2CA1F7A000E958D0 /* Parsing.swift */,
5CCBC6892CA1F7A000E958D0 /* PromptTemplate.swift */,
5CCBC68C2CA1F7A100E958D0 /* SystemPrompts.swift */,
5CCBC60B2CA1F04A00E958D0 /* LocalInference.h */,
);
path = LocalInferenceImpl;
sourceTree = "<group>";
};
5CCBC6772CA1F63F00E958D0 /* Products */ = {
isa = PBXGroup;
children = (
5CCBC67E2CA1F63F00E958D0 /* LLaMA.app */,
5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */,
5CCBC69F2CA2036B00E958D0 /* LLaMAPerfBenchmark.app */,
5CCBC6A12CA2036B00E958D0 /* LLaMAPerfBenchmarkTests.xctest */,
);
name = Products;
sourceTree = "<group>";
};
5CCBC6852CA1F64A00E958D0 /* Frameworks */ = {
isa = PBXGroup;
children = (
);
name = Frameworks;
sourceTree = "<group>";
};
/* End PBXGroup section */
/* Begin PBXHeadersBuildPhase section */
5CCBC6032CA1F04A00E958D0 /* Headers */ = {
isa = PBXHeadersBuildPhase;
buildActionMask = 2147483647;
files = (
5CCBC60C2CA1F04A00E958D0 /* LocalInference.h in Headers */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXHeadersBuildPhase section */
/* Begin PBXNativeTarget section */
5CCBC6072CA1F04A00E958D0 /* LocalInferenceImpl */ = {
isa = PBXNativeTarget;
buildConfigurationList = 5CCBC60F2CA1F04A00E958D0 /* Build configuration list for PBXNativeTarget "LocalInferenceImpl" */;
buildPhases = (
5CCBC6032CA1F04A00E958D0 /* Headers */,
5CCBC6042CA1F04A00E958D0 /* Sources */,
5CCBC6052CA1F04A00E958D0 /* Frameworks */,
5CCBC6062CA1F04A00E958D0 /* Resources */,
5CCBC6882CA1F64A00E958D0 /* Embed Frameworks */,
);
buildRules = (
);
dependencies = (
);
name = LocalInferenceImpl;
packageProductDependencies = (
5CCBC6742CA1F45800E958D0 /* executorch_debug */,
5CCBC6922CA1F7D000E958D0 /* Stencil */,
5CADC7192CA471CC007662D2 /* LlamaStackClient */,
5CAF3DD72CA485740029CD2B /* LlamaStackClient */,
);
productName = LocalInferenceProvider;
productReference = 5CCBC6082CA1F04A00E958D0 /* LocalInferenceImpl.framework */;
productType = "com.apple.product-type.framework";
};
/* End PBXNativeTarget section */
/* Begin PBXProject section */
5CCBC5FF2CA1F04A00E958D0 /* Project object */ = {
isa = PBXProject;
attributes = {
BuildIndependentTargetsInParallel = 1;
LastUpgradeCheck = 1540;
TargetAttributes = {
5CCBC6072CA1F04A00E958D0 = {
CreatedOnToolsVersion = 15.4;
LastSwiftMigration = 1540;
};
};
};
buildConfigurationList = 5CCBC6022CA1F04A00E958D0 /* Build configuration list for PBXProject "LocalInferenceImpl" */;
compatibilityVersion = "Xcode 14.0";
developmentRegion = en;
hasScannedForEncodings = 0;
knownRegions = (
en,
Base,
);
mainGroup = 5CCBC5FE2CA1F04A00E958D0;
packageReferences = (
5CCBC6732CA1F45800E958D0 /* XCRemoteSwiftPackageReference "executorch" */,
5CCBC6912CA1F7D000E958D0 /* XCRemoteSwiftPackageReference "Stencil" */,
5CAF3DD62CA485740029CD2B /* XCRemoteSwiftPackageReference "llama-stack-client-swift" */,
);
productRefGroup = 5CCBC6092CA1F04A00E958D0 /* Products */;
projectDirPath = "";
projectReferences = (
{
ProductGroup = 5CCBC6772CA1F63F00E958D0 /* Products */;
ProjectRef = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
},
);
projectRoot = "";
targets = (
5CCBC6072CA1F04A00E958D0 /* LocalInferenceImpl */,
);
};
/* End PBXProject section */
/* Begin PBXReferenceProxy section */
5CCBC67E2CA1F63F00E958D0 /* LLaMA.app */ = {
isa = PBXReferenceProxy;
fileType = wrapper.application;
path = LLaMA.app;
remoteRef = 5CCBC67D2CA1F63F00E958D0 /* PBXContainerItemProxy */;
sourceTree = BUILT_PRODUCTS_DIR;
};
5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */ = {
isa = PBXReferenceProxy;
fileType = wrapper.framework;
path = LLaMARunner.framework;
remoteRef = 5CCBC67F2CA1F63F00E958D0 /* PBXContainerItemProxy */;
sourceTree = BUILT_PRODUCTS_DIR;
};
5CCBC69F2CA2036B00E958D0 /* LLaMAPerfBenchmark.app */ = {
isa = PBXReferenceProxy;
fileType = wrapper.application;
path = LLaMAPerfBenchmark.app;
remoteRef = 5CCBC69E2CA2036B00E958D0 /* PBXContainerItemProxy */;
sourceTree = BUILT_PRODUCTS_DIR;
};
5CCBC6A12CA2036B00E958D0 /* LLaMAPerfBenchmarkTests.xctest */ = {
isa = PBXReferenceProxy;
fileType = wrapper.cfbundle;
path = LLaMAPerfBenchmarkTests.xctest;
remoteRef = 5CCBC6A02CA2036B00E958D0 /* PBXContainerItemProxy */;
sourceTree = BUILT_PRODUCTS_DIR;
};
/* End PBXReferenceProxy section */
/* Begin PBXResourcesBuildPhase section */
5CCBC6062CA1F04A00E958D0 /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXResourcesBuildPhase section */
/* Begin PBXSourcesBuildPhase section */
5CCBC6042CA1F04A00E958D0 /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
5CCBC6902CA1F7A100E958D0 /* SystemPrompts.swift in Sources */,
5CCBC68D2CA1F7A100E958D0 /* PromptTemplate.swift in Sources */,
5CCBC68F2CA1F7A100E958D0 /* Parsing.swift in Sources */,
5CCBC68E2CA1F7A100E958D0 /* LocalInference.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXSourcesBuildPhase section */
/* Begin XCBuildConfiguration section */
5CCBC60D2CA1F04A00E958D0 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 1;
DEBUG_INFORMATION_FORMAT = dwarf;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_TESTABILITY = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_DYNAMIC_NO_PIC = NO;
GCC_NO_COMMON_BLOCKS = YES;
GCC_OPTIMIZATION_LEVEL = 0;
GCC_PREPROCESSOR_DEFINITIONS = (
"DEBUG=1",
"$(inherited)",
);
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 17.5;
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
MTL_FAST_MATH = YES;
ONLY_ACTIVE_ARCH = YES;
SDKROOT = iphoneos;
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
VERSIONING_SYSTEM = "apple-generic";
VERSION_INFO_PREFIX = "";
};
name = Debug;
};
5CCBC60E2CA1F04A00E958D0 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 1;
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
ENABLE_NS_ASSERTIONS = NO;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_NO_COMMON_BLOCKS = YES;
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 17.5;
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MTL_ENABLE_DEBUG_INFO = NO;
MTL_FAST_MATH = YES;
SDKROOT = iphoneos;
SWIFT_COMPILATION_MODE = wholemodule;
VALIDATE_PRODUCT = YES;
VERSIONING_SYSTEM = "apple-generic";
VERSION_INFO_PREFIX = "";
};
name = Release;
};
5CCBC6102CA1F04A00E958D0 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
BUILD_LIBRARY_FOR_DISTRIBUTION = YES;
CLANG_ENABLE_MODULES = YES;
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEFINES_MODULE = YES;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
ENABLE_MODULE_VERIFIER = YES;
GENERATE_INFOPLIST_FILE = YES;
HEADER_SEARCH_PATHS = "";
INFOPLIST_KEY_NSHumanReadableCopyright = "";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
"@loader_path/Frameworks",
);
MARKETING_VERSION = 1.0;
MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++";
MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20";
OTHER_LDFLAGS = "";
PRODUCT_BUNDLE_IDENTIFIER = meta.llamatsack.LocalInferenceProvider;
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
SKIP_INSTALL = YES;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_INSTALL_OBJC_HEADER = NO;
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
};
name = Debug;
};
5CCBC6112CA1F04A00E958D0 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
BUILD_LIBRARY_FOR_DISTRIBUTION = YES;
CLANG_ENABLE_MODULES = YES;
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEFINES_MODULE = YES;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
ENABLE_MODULE_VERIFIER = YES;
GENERATE_INFOPLIST_FILE = YES;
HEADER_SEARCH_PATHS = "";
INFOPLIST_KEY_NSHumanReadableCopyright = "";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
"@loader_path/Frameworks",
);
MARKETING_VERSION = 1.0;
MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++";
MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20";
OTHER_LDFLAGS = "";
PRODUCT_BUNDLE_IDENTIFIER = meta.llamatsack.LocalInferenceProvider;
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
SKIP_INSTALL = YES;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_INSTALL_OBJC_HEADER = NO;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
};
name = Release;
};
/* End XCBuildConfiguration section */
/* Begin XCConfigurationList section */
5CCBC6022CA1F04A00E958D0 /* Build configuration list for PBXProject "LocalInferenceImpl" */ = {
isa = XCConfigurationList;
buildConfigurations = (
5CCBC60D2CA1F04A00E958D0 /* Debug */,
5CCBC60E2CA1F04A00E958D0 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
5CCBC60F2CA1F04A00E958D0 /* Build configuration list for PBXNativeTarget "LocalInferenceImpl" */ = {
isa = XCConfigurationList;
buildConfigurations = (
5CCBC6102CA1F04A00E958D0 /* Debug */,
5CCBC6112CA1F04A00E958D0 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
/* End XCConfigurationList section */
/* Begin XCRemoteSwiftPackageReference section */
5CAF3DD62CA485740029CD2B /* XCRemoteSwiftPackageReference "llama-stack-client-swift" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/meta-llama/llama-stack-client-swift";
requirement = {
branch = main;
kind = branch;
};
};
5CCBC6732CA1F45800E958D0 /* XCRemoteSwiftPackageReference "executorch" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/pytorch/executorch";
requirement = {
branch = latest;
kind = branch;
};
};
5CCBC6912CA1F7D000E958D0 /* XCRemoteSwiftPackageReference "Stencil" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/stencilproject/Stencil";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 0.15.1;
};
};
/* End XCRemoteSwiftPackageReference section */
/* Begin XCSwiftPackageProductDependency section */
5CADC7192CA471CC007662D2 /* LlamaStackClient */ = {
isa = XCSwiftPackageProductDependency;
productName = LlamaStackClient;
};
5CAF3DD72CA485740029CD2B /* LlamaStackClient */ = {
isa = XCSwiftPackageProductDependency;
package = 5CAF3DD62CA485740029CD2B /* XCRemoteSwiftPackageReference "llama-stack-client-swift" */;
productName = LlamaStackClient;
};
5CCBC6742CA1F45800E958D0 /* executorch_debug */ = {
isa = XCSwiftPackageProductDependency;
package = 5CCBC6732CA1F45800E958D0 /* XCRemoteSwiftPackageReference "executorch" */;
productName = executorch_debug;
};
5CCBC6922CA1F7D000E958D0 /* Stencil */ = {
isa = XCSwiftPackageProductDependency;
package = 5CCBC6912CA1F7D000E958D0 /* XCRemoteSwiftPackageReference "Stencil" */;
productName = Stencil;
};
/* End XCSwiftPackageProductDependency section */
};
rootObject = 5CCBC5FF2CA1F04A00E958D0 /* Project object */;
}

View file

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<Workspace
version = "1.0">
<FileRef
location = "self:">
</FileRef>
</Workspace>

View file

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>IDEDidComputeMac32BitWarning</key>
<true/>
</dict>
</plist>

View file

@ -0,0 +1,9 @@
#import <Foundation/Foundation.h>
//! Project version number for LocalInference.
FOUNDATION_EXPORT double LocalInferenceVersionNumber;
//! Project version string for LocalInference.
FOUNDATION_EXPORT const unsigned char LocalInferenceVersionString[];
// In this header, you should import all the public headers of your framework using statements like #import <LocalInference/PublicHeader.h>

View file

@ -0,0 +1,189 @@
import Foundation
import LLaMARunner
import LlamaStackClient
class RunnerHolder: ObservableObject {
var runner: Runner?
}
public class LocalInference: Inference {
private var runnerHolder = RunnerHolder()
private let runnerQueue: DispatchQueue
public init (queue: DispatchQueue) {
runnerQueue = queue
}
public func loadModel(modelPath: String, tokenizerPath: String, completion: @escaping (Result<Void, Error>) -> Void) {
runnerHolder.runner = runnerHolder.runner ?? Runner(
modelPath: modelPath,
tokenizerPath: tokenizerPath
)
runnerQueue.async {
let runner = self.runnerHolder.runner
do {
try runner!.load()
completion(.success(()))
} catch let loadError {
print("error: " + loadError.localizedDescription)
completion(.failure(loadError))
}
}
}
public func stop() {
runnerHolder.runner?.stop()
}
public func chatCompletion(request: Components.Schemas.ChatCompletionRequest) -> AsyncStream<Components.Schemas.ChatCompletionResponseStreamChunk> {
return AsyncStream { continuation in
let workItem = DispatchWorkItem {
do {
var tokens: [String] = []
let prompt = try encodeDialogPrompt(messages: prepareMessages(request: request))
var stopReason: Components.Schemas.CompletionMessage.stop_reasonPayload? = nil
var buffer = ""
var ipython = false
var echoDropped = false
try self.runnerHolder.runner?.generate(prompt, sequenceLength: 4096) { token in
buffer += token
// HACK: Workaround until LlamaRunner exposes echo param
if (!echoDropped) {
if (buffer.hasPrefix(prompt)) {
buffer = String(buffer.dropFirst(prompt.count))
echoDropped = true
}
return
}
tokens.append(token)
if !ipython && (buffer.starts(with: "<|python_tag|>") || buffer.starts(with: "[") ) {
ipython = true
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
event_type: .progress,
delta: .tool_call(Components.Schemas.ToolCallDelta(
_type: Components.Schemas.ToolCallDelta._typePayload.tool_call,
tool_call: .case1(""),
parse_status: Components.Schemas.ToolCallDelta.parse_statusPayload.started
)
)
)
)
)
if (buffer.starts(with: "<|python_tag|>")) {
buffer = String(buffer.dropFirst("<|python_tag|>".count))
}
}
// TODO: Non-streaming lobprobs
var text = ""
if token == "<|eot_id|>" {
stopReason = Components.Schemas.CompletionMessage.stop_reasonPayload.end_of_turn
} else if token == "<|eom_id|>" {
stopReason = Components.Schemas.CompletionMessage.stop_reasonPayload.end_of_message
} else {
text = token
}
var delta: Components.Schemas.ContentDelta
if ipython {
delta = .tool_call(Components.Schemas.ToolCallDelta(
_type: .tool_call,
tool_call: .case1(text),
parse_status: .in_progress
))
} else {
delta = .text(Components.Schemas.TextDelta(
_type: Components.Schemas.TextDelta._typePayload.text,
text: text
)
)
}
if stopReason == nil {
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
event_type: .progress,
delta: delta
)
)
)
}
}
if stopReason == nil {
stopReason = Components.Schemas.CompletionMessage.stop_reasonPayload.out_of_tokens
}
let message = decodeAssistantMessage(tokens: tokens.joined(), stopReason: stopReason!)
// TODO: non-streaming support
let didParseToolCalls = message.tool_calls?.count ?? 0 > 0
if ipython && !didParseToolCalls {
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
event_type: .progress,
delta: .tool_call(Components.Schemas.ToolCallDelta(
_type: Components.Schemas.ToolCallDelta._typePayload.tool_call,
tool_call: .case1(""),
parse_status: Components.Schemas.ToolCallDelta.parse_statusPayload.failed
)
)
)
// TODO: stopReason
)
)
}
for toolCall in message.tool_calls! {
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
event_type: .progress,
delta: .tool_call(Components.Schemas.ToolCallDelta(
_type: Components.Schemas.ToolCallDelta._typePayload.tool_call,
tool_call: Components.Schemas.ToolCallDelta.tool_callPayload.ToolCall(toolCall),
parse_status: Components.Schemas.ToolCallDelta.parse_statusPayload.succeeded
)
)
)
// TODO: stopReason
)
)
}
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
event_type: .complete,
delta: .text(Components.Schemas.TextDelta(
_type: Components.Schemas.TextDelta._typePayload.text,
text: ""
)
)
)
// TODO: stopReason
)
)
}
catch (let error) {
print("Inference error: " + error.localizedDescription)
}
}
runnerQueue.async(execute: workItem)
}
}
}

View file

@ -0,0 +1,238 @@
import Foundation
import LlamaStackClient
func encodeHeader(role: String) -> String {
return "<|start_header_id|>\(role)<|end_header_id|>\n\n"
}
func encodeDialogPrompt(messages: [Components.Schemas.Message]) -> String {
var prompt = ""
prompt.append("<|begin_of_text|>")
for message in messages {
let msg = encodeMessage(message: message)
prompt += msg
}
prompt.append(encodeHeader(role: "assistant"))
return prompt
}
func getRole(message: Components.Schemas.Message) -> String {
switch (message) {
case .user(let m):
return m.role.rawValue
case .system(let m):
return m.role.rawValue
case .tool(let m):
return m.role.rawValue
case .assistant(let m):
return m.role.rawValue
}
}
func encodeMessage(message: Components.Schemas.Message) -> String {
var prompt = encodeHeader(role: getRole(message: message))
switch (message) {
case .assistant(let m):
if (m.tool_calls?.count ?? 0 > 0) {
prompt += "<|python_tag|>"
}
default:0
break
}
func _processContent(_ content: Any) -> String {
func _process(_ c: Any) {
if let str = c as? String {
prompt += str
}
}
if let str = content as? String {
_process(str)
} else if let list = content as? [Any] {
for c in list {
_process(c)
}
}
return ""
}
switch (message) {
case .user(let m):
prompt += _processContent(m.content)
case .system(let m):
prompt += _processContent(m.content)
case .tool(let m):
prompt += _processContent(m.content)
case .assistant(let m):
prompt += _processContent(m.content)
}
var eom = false
switch (message) {
case .user(let m):
switch (m.content) {
case .case1(let c):
prompt += _processContent(c)
case .InterleavedContentItem(let c):
prompt += _processContent(c)
case .case3(let c):
prompt += _processContent(c)
}
case .assistant(let m):
// TODO: Support encoding past tool call history
// for t in m.tool_calls {
// _processContent(t.)
//}
eom = m.stop_reason == Components.Schemas.CompletionMessage.stop_reasonPayload.end_of_message
case .system(_):
break
case .tool(_):
break
}
if (eom) {
prompt += "<|eom_id|>"
} else {
prompt += "<|eot_id|>"
}
return prompt
}
func prepareMessages(request: Components.Schemas.ChatCompletionRequest) throws -> [Components.Schemas.Message] {
var existingMessages = request.messages
var existingSystemMessage: Components.Schemas.Message?
// TODO: Existing system message
var messages: [Components.Schemas.Message] = []
let defaultGen = SystemDefaultGenerator()
let defaultTemplate = defaultGen.gen()
var sysContent = ""
// TODO: Built-in tools
sysContent += try defaultTemplate.render()
messages.append(.system(Components.Schemas.SystemMessage(
role: .system,
content: .case1(sysContent)
))
)
if request.tools?.isEmpty == false {
// TODO: Separate built-ins and custom tools (right now everything treated as custom)
let toolGen = FunctionTagCustomToolGenerator()
let toolTemplate = try toolGen.gen(customTools: request.tools!)
let tools = try toolTemplate.render()
messages.append(.user(Components.Schemas.UserMessage(
role: .user,
content: .case1(tools))
))
}
messages.append(contentsOf: existingMessages)
return messages
}
struct FunctionCall {
let name: String
let params: [String: Any]
}
public func maybeExtractCustomToolCalls(input: String) -> [Components.Schemas.ToolCall] {
guard input.hasPrefix("[") && input.hasSuffix("]") else {
return []
}
do {
let trimmed = input.trimmingCharacters(in: CharacterSet(charactersIn: "[]"))
let calls = trimmed.components(separatedBy: "),").map { $0.hasSuffix(")") ? $0 : $0 + ")" }
var result: [Components.Schemas.ToolCall] = []
for call in calls {
guard let nameEndIndex = call.firstIndex(of: "("),
let paramsStartIndex = call.firstIndex(of: "{"),
let paramsEndIndex = call.lastIndex(of: "}") else {
return []
}
let name = String(call[..<nameEndIndex]).trimmingCharacters(in: .whitespacesAndNewlines)
let paramsString = String(call[paramsStartIndex...paramsEndIndex])
guard let data = paramsString.data(using: .utf8),
let params = try? JSONSerialization.jsonObject(with: data, options: []) as? [String: Any] else {
return []
}
var props: [String : Components.Schemas.ToolCall.argumentsPayload.additionalPropertiesPayload] = [:]
for (param_name, param) in params {
switch (param) {
case let value as String:
props[param_name] = .case1(value)
case let value as Int:
props[param_name] = .case2(value)
case let value as Double:
props[param_name] = .case3(value)
case let value as Bool:
props[param_name] = .case4(value)
default:
return []
}
}
result.append(
Components.Schemas.ToolCall(
call_id: UUID().uuidString,
tool_name: .case2(name), // custom_tool
arguments: .init(additionalProperties: props)
)
)
}
return result.isEmpty ? [] : result
} catch {
return []
}
}
func decodeAssistantMessage(tokens: String, stopReason: Components.Schemas.CompletionMessage.stop_reasonPayload) -> Components.Schemas.CompletionMessage {
var content = tokens
let roles = ["user", "system", "assistant"]
for role in roles {
let headerStr = encodeHeader(role: role)
if content.hasPrefix(headerStr) {
content = String(content.dropFirst(encodeHeader(role: role).count))
}
}
if content.hasPrefix("<|python_tag|>") {
content = String(content.dropFirst("<|python_tag|>".count))
}
if content.hasSuffix("<|eot_id|>") {
content = String(content.dropLast("<|eot_id|>".count))
} else {
content = String(content.dropLast("<|eom_id|>".count))
}
return Components.Schemas.CompletionMessage(
role: .assistant,
content: .case1(content),
stop_reason: stopReason,
tool_calls: maybeExtractCustomToolCalls(input: content)
)
}

View file

@ -0,0 +1,12 @@
import Foundation
import Stencil
public struct PromptTemplate {
let template: String
let data: [String: Any]
public func render() throws -> String {
let template = Template(templateString: self.template)
return try template.render(self.data)
}
}

View file

@ -0,0 +1,89 @@
import Foundation
import LlamaStackClient
func convertToNativeSwiftType(_ value: Any) -> Any {
switch value {
case let number as NSNumber:
if CFGetTypeID(number) == CFBooleanGetTypeID() {
return number.boolValue
}
if floor(number.doubleValue) == number.doubleValue {
return number.intValue
}
return number.doubleValue
case let string as String:
return string
case let array as [Any]:
return array.map(convertToNativeSwiftType)
case let dict as [String: Any]:
return dict.mapValues(convertToNativeSwiftType)
case is NSNull:
return NSNull()
default:
return value
}
}
public class SystemDefaultGenerator {
public init() {}
public func gen() -> PromptTemplate {
let templateStr = """
Cutting Knowledge Date: December 2023
Today Date: {{ today }}
"""
let dateFormatter = DateFormatter()
dateFormatter.dateFormat = "dd MMMM yyyy"
return PromptTemplate(
template: templateStr,
data: ["today": dateFormatter.string(from: Date())]
)
}
}
public class FunctionTagCustomToolGenerator {
public init() {}
public func gen(customTools: [Components.Schemas.ToolDefinition]) throws -> PromptTemplate {
// TODO: required params
// TODO: {{#unless @last}},{{/unless}}
let templateStr = """
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[
{% for t in custom_tools %}
{
"name": "{{t.tool_name}}",
"description": "{{t.description}}",
"input_schema": { {{t.input_schema}} }
}
{{/let}}
{% endfor -%}
]
"""
let encoder = JSONEncoder()
return PromptTemplate(
template: templateStr,
data: ["custom_tools": try customTools.map {
let data = try encoder.encode($0)
let obj = try JSONSerialization.jsonObject(with: data)
return convertToNativeSwiftType(obj)
}]
)
}
}

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import gc
def evacuate_model_from_device(model, device: str):
"""Safely clear a model from memory and free device resources.
This function handles the proper cleanup of a model by:
1. Moving the model to CPU if it's on a non-CPU device
2. Deleting the model object to free memory
3. Running garbage collection
4. Clearing CUDA cache if the model was on a CUDA device
Args:
model: The PyTorch model to clear
device: The device type the model is currently on ('cuda', 'mps', 'cpu')
Note:
- For CUDA devices, this will clear the CUDA cache after moving the model to CPU
- For MPS devices, only moves the model to CPU (no cache clearing available)
- For CPU devices, only deletes the model object and runs garbage collection
"""
if device != "cpu":
model.to("cpu")
del model
gc.collect()
if device == "cuda":
# we need to import such that this is only imported when the method is called
import torch
torch.cuda.empty_cache()

View file

@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.type_system import (
ChatCompletionInputType,
DialogType,
StringType,
)
from llama_stack.providers.utils.common.data_schema_validator import (
ColumnName,
)
EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
"instruct": [
{
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
ColumnName.expected_answer.value: StringType(),
}
],
"dialog": [
{
ColumnName.dialog.value: DialogType(),
}
],
}

View file

@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.core.datatypes import Api
from .config import HuggingFacePostTrainingConfig
# post_training api and the huggingface provider is still experimental and under heavy development
async def get_provider_impl(
config: HuggingFacePostTrainingConfig,
deps: dict[Api, Any],
):
from .post_training import HuggingFacePostTrainingImpl
impl = HuggingFacePostTrainingImpl(
config,
deps[Api.datasetio],
deps[Api.datasets],
)
return impl

View file

@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Literal
from pydantic import BaseModel
class HuggingFacePostTrainingConfig(BaseModel):
# Device to run training on (cuda, cpu, mps)
device: str = "cuda"
# Distributed training backend if using multiple devices
# fsdp: Fully Sharded Data Parallel
# deepspeed: DeepSpeed ZeRO optimization
distributed_backend: Literal["fsdp", "deepspeed"] | None = None
# Format for saving model checkpoints
# full_state: Save complete model state
# huggingface: Save in HuggingFace format (recommended for compatibility)
checkpoint_format: Literal["full_state", "huggingface"] | None = "huggingface"
# Template for formatting chat inputs and outputs
# Used to structure the conversation format for training
chat_template: str = "<|user|>\n{input}\n<|assistant|>\n{output}"
# Model-specific configuration parameters
# trust_remote_code: Allow execution of custom model code
# attn_implementation: Use SDPA (Scaled Dot Product Attention) for better performance
model_specific_config: dict = {
"trust_remote_code": True,
"attn_implementation": "sdpa",
}
# Maximum sequence length for training
# Set to 2048 as this is the maximum that works reliably on MPS (Apple Silicon)
# Longer sequences may cause memory issues on MPS devices
max_seq_length: int = 2048
# Enable gradient checkpointing to reduce memory usage
# Trades computation for memory by recomputing activations
gradient_checkpointing: bool = False
# Maximum number of checkpoints to keep
# Older checkpoints are deleted when this limit is reached
save_total_limit: int = 3
# Number of training steps between logging updates
logging_steps: int = 10
# Ratio of training steps used for learning rate warmup
# Helps stabilize early training
warmup_ratio: float = 0.1
# L2 regularization coefficient
# Helps prevent overfitting
weight_decay: float = 0.01
# Number of worker processes for data loading
# Higher values can improve data loading speed but increase memory usage
dataloader_num_workers: int = 4
# Whether to pin memory in data loader
# Can improve data transfer speed to GPU but uses more memory
dataloader_pin_memory: bool = True
# DPO-specific parameters
dpo_beta: float = 0.1
use_reference_model: bool = True
dpo_loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
dpo_output_dir: str
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"checkpoint_format": "huggingface",
"distributed_backend": None,
"device": "cpu",
"dpo_output_dir": __distro_dir__ + "/dpo_output",
}

View file

@ -0,0 +1,208 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
AlgorithmConfig,
Checkpoint,
DPOAlignmentConfig,
JobStatus,
ListPostTrainingJobsResponse,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
TrainingConfig,
)
from llama_stack.providers.inline.post_training.huggingface.config import (
HuggingFacePostTrainingConfig,
)
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
class TrainingArtifactType(Enum):
CHECKPOINT = "checkpoint"
RESOURCES_STATS = "resources_stats"
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
_JOB_TYPE_DPO_TRAINING = "dpo-training"
class HuggingFacePostTrainingImpl:
def __init__(
self,
config: HuggingFacePostTrainingConfig,
datasetio_api: DatasetIO,
datasets: Datasets,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets
self._scheduler = Scheduler()
async def shutdown(self) -> None:
await self._scheduler.shutdown()
@staticmethod
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.CHECKPOINT.value,
name=checkpoint.identifier,
uri=checkpoint.path,
metadata=dict(checkpoint),
)
@staticmethod
def _resources_stats_to_artifact(resources_stats: dict[str, Any]) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.RESOURCES_STATS.value,
name=TrainingArtifactType.RESOURCES_STATS.value,
metadata=resources_stats,
)
async def supervised_fine_tune(
self,
job_uuid: str,
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
model: str,
checkpoint_dir: str | None = None,
algorithm_config: AlgorithmConfig | None = None,
) -> PostTrainingJob:
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
HFFinetuningSingleDevice,
)
on_log_message_cb("Starting HF finetuning")
recipe = HFFinetuningSingleDevice(
job_uuid=job_uuid,
datasetio_api=self.datasetio_api,
datasets_api=self.datasets_api,
)
resources_allocated, checkpoints = await recipe.train(
model=model,
output_dir=checkpoint_dir,
job_uuid=job_uuid,
lora_config=algorithm_config,
config=training_config,
provider_config=self.config,
)
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
if checkpoints:
for checkpoint in checkpoints:
artifact = self._checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
on_status_change_cb(SchedulerJobStatus.completed)
on_log_message_cb("HF finetuning completed")
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
return PostTrainingJob(job_uuid=job_uuid)
async def preference_optimize(
self,
job_uuid: str,
finetuned_model: str,
algorithm_config: DPOAlignmentConfig,
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
) -> PostTrainingJob:
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import (
HFDPOAlignmentSingleDevice,
)
on_log_message_cb("Starting HF DPO alignment")
recipe = HFDPOAlignmentSingleDevice(
job_uuid=job_uuid,
datasetio_api=self.datasetio_api,
datasets_api=self.datasets_api,
)
resources_allocated, checkpoints = await recipe.train(
model=finetuned_model,
output_dir=f"{self.config.dpo_output_dir}/{job_uuid}",
job_uuid=job_uuid,
dpo_config=algorithm_config,
config=training_config,
provider_config=self.config,
)
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
if checkpoints:
for checkpoint in checkpoints:
artifact = self._checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
else:
on_log_message_cb("Warning: No checkpoints were saved during DPO training")
on_status_change_cb(SchedulerJobStatus.completed)
on_log_message_cb("HF DPO alignment completed")
job_uuid = self._scheduler.schedule(_JOB_TYPE_DPO_TRAINING, job_uuid, handler)
return PostTrainingJob(job_uuid=job_uuid)
@staticmethod
def _get_artifacts_metadata_by_type(job, artifact_type):
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
@classmethod
def _get_checkpoints(cls, job):
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
@classmethod
def _get_resources_allocated(cls, job):
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
return data[0] if data else None
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
job = self._scheduler.get_job(job_uuid)
match job.status:
# TODO: Add support for other statuses to API
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
status = JobStatus.scheduled
case SchedulerJobStatus.running:
status = JobStatus.in_progress
case SchedulerJobStatus.completed:
status = JobStatus.completed
case SchedulerJobStatus.failed:
status = JobStatus.failed
case _:
raise NotImplementedError()
return PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=status,
scheduled_at=job.scheduled_at,
started_at=job.started_at,
completed_at=job.completed_at,
checkpoints=self._get_checkpoints(job),
resources_allocated=self._get_resources_allocated(job),
)
async def cancel_training_job(self, job_uuid: str) -> None:
self._scheduler.cancel(job_uuid)
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
job = self._scheduler.get_job(job_uuid)
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,519 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import gc
import json
import multiprocessing
from pathlib import Path
from typing import Any
import torch
from datasets import Dataset
from peft import LoraConfig
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
)
from trl import SFTConfig, SFTTrainer
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
Checkpoint,
DataConfig,
LoraFinetuningConfig,
TrainingConfig,
)
from llama_stack.log import get_logger
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from ..config import HuggingFacePostTrainingConfig
from ..utils import (
calculate_training_steps,
create_checkpoints,
get_memory_stats,
get_save_strategy,
load_model,
load_rows_from_dataset,
setup_environment,
setup_signal_handlers,
setup_torch_device,
split_dataset,
)
logger = get_logger(name=__name__, category="post_training")
class HFFinetuningSingleDevice:
def __init__(
self,
job_uuid: str,
datasetio_api: DatasetIO,
datasets_api: Datasets,
):
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.job_uuid = job_uuid
def validate_dataset_format(self, rows: list[dict]) -> bool:
"""Validate that the dataset has the required fields."""
required_fields = ["input_query", "expected_answer", "chat_completion_input"]
return all(field in row for row in rows for field in required_fields)
def _process_instruct_format(self, row: dict) -> tuple[str | None, str | None]:
"""Process a row in instruct format."""
if "chat_completion_input" in row and "expected_answer" in row:
try:
messages = json.loads(row["chat_completion_input"])
if not isinstance(messages, list) or len(messages) != 1:
logger.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}")
return None, None
if "content" not in messages[0]:
logger.warning(f"Message missing content: {messages[0]}")
return None, None
return messages[0]["content"], row["expected_answer"]
except json.JSONDecodeError:
logger.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}")
return None, None
return None, None
def _process_dialog_format(self, row: dict) -> tuple[str | None, str | None]:
"""Process a row in dialog format."""
if "dialog" in row:
try:
dialog = json.loads(row["dialog"])
if not isinstance(dialog, list) or len(dialog) < 2:
logger.warning(f"Dialog must have at least 2 messages: {row['dialog']}")
return None, None
if dialog[0].get("role") != "user":
logger.warning(f"First message must be from user: {dialog[0]}")
return None, None
if not any(msg.get("role") == "assistant" for msg in dialog):
logger.warning("Dialog must have at least one assistant message")
return None, None
# Convert to human/gpt format
role_map = {"user": "human", "assistant": "gpt"}
conversations = []
for msg in dialog:
if "role" not in msg or "content" not in msg:
logger.warning(f"Message missing role or content: {msg}")
continue
conversations.append({"from": role_map[msg["role"]], "value": msg["content"]})
# Format as a single conversation
return conversations[0]["value"], conversations[1]["value"]
except json.JSONDecodeError:
logger.warning(f"Failed to parse dialog: {row['dialog']}")
return None, None
return None, None
def _process_fallback_format(self, row: dict) -> tuple[str | None, str | None]:
"""Process a row using fallback formats."""
if "input" in row and "output" in row:
return row["input"], row["output"]
elif "prompt" in row and "completion" in row:
return row["prompt"], row["completion"]
elif "question" in row and "answer" in row:
return row["question"], row["answer"]
return None, None
def _format_text(self, input_text: str, output_text: str, provider_config: HuggingFacePostTrainingConfig) -> str:
"""Format input and output text based on model requirements."""
if hasattr(provider_config, "chat_template"):
return provider_config.chat_template.format(input=input_text, output=output_text)
return f"{input_text}\n{output_text}"
def _create_dataset(
self, rows: list[dict], config: TrainingConfig, provider_config: HuggingFacePostTrainingConfig
) -> Dataset:
"""Create and preprocess the dataset."""
formatted_rows = []
for row in rows:
input_text = None
output_text = None
# Process based on format
assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized"
if config.data_config.data_format.value == "instruct":
input_text, output_text = self._process_instruct_format(row)
elif config.data_config.data_format.value == "dialog":
input_text, output_text = self._process_dialog_format(row)
else:
input_text, output_text = self._process_fallback_format(row)
if input_text and output_text:
formatted_text = self._format_text(input_text, output_text, provider_config)
formatted_rows.append({"text": formatted_text})
if not formatted_rows:
assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized"
raise ValueError(
f"No valid input/output pairs found in the dataset for format: {config.data_config.data_format.value}"
)
return Dataset.from_list(formatted_rows)
def _preprocess_dataset(
self, ds: Dataset, tokenizer: AutoTokenizer, provider_config: HuggingFacePostTrainingConfig
) -> Dataset:
"""Preprocess the dataset with tokenizer."""
def tokenize_function(examples):
return tokenizer(
examples["text"],
padding=True,
truncation=True,
max_length=provider_config.max_seq_length,
return_tensors=None,
)
return ds.map(
tokenize_function,
batched=True,
remove_columns=ds.column_names,
)
def _run_training_sync(
self,
model: str,
provider_config: dict[str, Any],
peft_config: LoraConfig | None,
config: dict[str, Any],
output_dir_path: Path | None,
) -> None:
"""Synchronous wrapper for running training process.
This method serves as a bridge between the multiprocessing Process and the async training function.
It creates a new event loop to run the async training process.
Args:
model: The model identifier to load
dataset_id: ID of the dataset to use for training
provider_config: Configuration specific to the HuggingFace provider
peft_config: Optional LoRA configuration
config: General training configuration
output_dir_path: Optional path to save the model
"""
import asyncio
logger.info("Starting training process with async wrapper")
asyncio.run(
self._run_training(
model=model,
provider_config=provider_config,
peft_config=peft_config,
config=config,
output_dir_path=output_dir_path,
)
)
async def load_dataset(
self,
model: str,
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
) -> tuple[Dataset, Dataset, AutoTokenizer]:
"""Load and prepare the dataset for training.
Args:
model: The model identifier to load
config: Training configuration
provider_config: Provider-specific configuration
Returns:
tuple: (train_dataset, eval_dataset, tokenizer)
"""
# Validate data config
if not config.data_config:
raise ValueError("DataConfig is required for training")
# Load dataset
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
if not self.validate_dataset_format(rows):
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
logger.info(f"Loaded {len(rows)} rows from dataset")
# Initialize tokenizer
logger.info(f"Initializing tokenizer for model: {model}")
try:
tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config)
# Set pad token to eos token if not present
# This is common for models that don't have a dedicated pad token
if not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.eos_token
# Set padding side to right for causal language modeling
# This ensures that padding tokens don't interfere with the model's ability
# to predict the next token in the sequence
tokenizer.padding_side = "right"
# Set truncation side to right to keep the beginning of the sequence
# This is important for maintaining context and instruction format
tokenizer.truncation_side = "right"
# Set model max length to match provider config
# This ensures consistent sequence lengths across the training process
tokenizer.model_max_length = provider_config.max_seq_length
logger.info("Tokenizer initialized successfully")
except Exception as e:
raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e
# Create and preprocess dataset
logger.info("Creating and preprocessing dataset")
try:
ds = self._create_dataset(rows, config, provider_config)
ds = self._preprocess_dataset(ds, tokenizer, provider_config)
logger.info(f"Dataset created with {len(ds)} examples")
except Exception as e:
raise ValueError(f"Failed to create dataset: {str(e)}") from e
# Split dataset
train_dataset, eval_dataset = split_dataset(ds)
return train_dataset, eval_dataset, tokenizer
def setup_training_args(
self,
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
device: torch.device,
output_dir_path: Path | None,
steps_per_epoch: int,
) -> SFTConfig:
"""Setup training arguments.
Args:
config: Training configuration
provider_config: Provider-specific configuration
device: The device to train on
output_dir_path: Optional path to save the model
steps_per_epoch: Number of steps per epoch
Returns:
Configured SFTConfig object
"""
logger.info("Configuring training arguments")
lr = 2e-5
if config.optimizer_config:
lr = config.optimizer_config.lr
logger.info(f"Using custom learning rate: {lr}")
# Validate data config
if not config.data_config:
raise ValueError("DataConfig is required for training")
data_config = config.data_config
# Calculate steps and get save strategy
step_info = calculate_training_steps(steps_per_epoch, config)
save_strategy, eval_strategy = get_save_strategy(output_dir_path)
return SFTConfig(
max_steps=step_info["max_steps"],
output_dir=str(output_dir_path) if output_dir_path is not None else None,
num_train_epochs=config.n_epochs,
per_device_train_batch_size=data_config.batch_size,
fp16=device.type == "cuda",
bf16=False, # Causes CPU issues.
eval_strategy=eval_strategy,
use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False,
save_strategy=save_strategy,
report_to="none",
max_length=provider_config.max_seq_length,
gradient_accumulation_steps=config.gradient_accumulation_steps,
gradient_checkpointing=provider_config.gradient_checkpointing,
learning_rate=lr,
warmup_ratio=provider_config.warmup_ratio,
weight_decay=provider_config.weight_decay,
remove_unused_columns=False,
dataloader_pin_memory=provider_config.dataloader_pin_memory,
dataloader_num_workers=provider_config.dataloader_num_workers,
dataset_text_field="text",
packing=False,
load_best_model_at_end=True if output_dir_path else False,
metric_for_best_model="eval_loss",
greater_is_better=False,
logging_steps=step_info["logging_steps"],
)
def save_model(
self,
model_obj: AutoModelForCausalLM,
trainer: SFTTrainer,
peft_config: LoraConfig | None,
output_dir_path: Path,
) -> None:
"""Save the trained model.
Args:
model_obj: The model to save
trainer: The trainer instance
peft_config: Optional LoRA configuration
output_dir_path: Path to save the model
"""
logger.info("Saving final model")
model_obj.config.use_cache = True
if peft_config:
logger.info("Merging LoRA weights with base model")
model_obj = trainer.model.merge_and_unload()
else:
model_obj = trainer.model
save_path = output_dir_path / "merged_model"
logger.info(f"Saving model to {save_path}")
model_obj.save_pretrained(save_path)
async def _run_training(
self,
model: str,
provider_config: dict[str, Any],
peft_config: LoraConfig | None,
config: dict[str, Any],
output_dir_path: Path | None,
) -> None:
"""Run the training process with signal handling."""
# Setup environment variables
setup_environment()
# Setup signal handlers
setup_signal_handlers()
# Convert config dicts back to objects
logger.info("Initializing configuration objects")
provider_config_obj = HuggingFacePostTrainingConfig(**provider_config)
config_obj = TrainingConfig(**config)
# Initialize and validate device
device = setup_torch_device(provider_config_obj.device)
logger.info(f"Using device '{device}'")
# Load dataset and tokenizer
train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)
# Calculate steps per epoch
if not config_obj.data_config:
raise ValueError("DataConfig is required for training")
steps_per_epoch = len(train_dataset) // config_obj.data_config.batch_size
# Setup training arguments
training_args = self.setup_training_args(
config_obj,
provider_config_obj,
device,
output_dir_path,
steps_per_epoch,
)
# Load model
model_obj = load_model(model, device, provider_config_obj)
# Initialize trainer
logger.info("Initializing SFTTrainer")
trainer = SFTTrainer(
model=model_obj,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=peft_config,
args=training_args,
)
try:
# Train
logger.info("Starting training")
trainer.train()
logger.info("Training completed successfully")
# Save final model if output directory is provided
if output_dir_path:
self.save_model(model_obj, trainer, peft_config, output_dir_path)
finally:
# Clean up resources
logger.info("Cleaning up resources")
if hasattr(trainer, "model"):
evacuate_model_from_device(trainer.model, device.type)
del trainer
gc.collect()
logger.info("Cleanup completed")
async def train(
self,
model: str,
output_dir: str | None,
job_uuid: str,
lora_config: LoraFinetuningConfig,
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
) -> tuple[dict[str, Any], list[Checkpoint] | None]:
"""Train a model using HuggingFace's SFTTrainer"""
# Initialize and validate device
device = setup_torch_device(provider_config.device)
logger.info(f"Using device '{device}'")
output_dir_path = None
if output_dir:
output_dir_path = Path(output_dir)
# Track memory stats
memory_stats = {
"initial": get_memory_stats(device),
"after_training": None,
"final": None,
}
# Configure LoRA
peft_config = None
if lora_config:
peft_config = LoraConfig(
lora_alpha=lora_config.alpha,
lora_dropout=0.1,
r=lora_config.rank,
bias="none",
task_type="CAUSAL_LM",
target_modules=lora_config.lora_attn_modules,
)
# Validate data config
if not config.data_config:
raise ValueError("DataConfig is required for training")
# Train in a separate process
logger.info("Starting training in separate process")
try:
# Setup multiprocessing for device
if device.type in ["cuda", "mps"]:
multiprocessing.set_start_method("spawn", force=True)
process = multiprocessing.Process(
target=self._run_training_sync,
kwargs={
"model": model,
"provider_config": provider_config.model_dump(),
"peft_config": peft_config,
"config": config.model_dump(),
"output_dir_path": output_dir_path,
},
)
process.start()
# Monitor the process
while process.is_alive():
process.join(timeout=1) # Check every second
if not process.is_alive():
break
# Get the return code
if process.exitcode != 0:
raise RuntimeError(f"Training failed with exit code {process.exitcode}")
memory_stats["after_training"] = get_memory_stats(device)
checkpoints = []
if output_dir_path:
checkpoints = create_checkpoints(output_dir_path, job_uuid, model, config, "merged_model")
return memory_stats, checkpoints if checkpoints else None
finally:
memory_stats["final"] = get_memory_stats(device)
gc.collect()

View file

@ -0,0 +1,485 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import gc
import multiprocessing
from pathlib import Path
from typing import Any
import torch
from datasets import Dataset
from transformers import (
AutoTokenizer,
)
from trl import DPOConfig, DPOTrainer
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
Checkpoint,
DPOAlignmentConfig,
TrainingConfig,
)
from llama_stack.log import get_logger
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from ..config import HuggingFacePostTrainingConfig
from ..utils import (
calculate_training_steps,
create_checkpoints,
get_memory_stats,
get_save_strategy,
load_model,
load_rows_from_dataset,
setup_environment,
setup_signal_handlers,
setup_torch_device,
split_dataset,
)
logger = get_logger(name=__name__, category="post_training")
class HFDPOAlignmentSingleDevice:
def __init__(
self,
job_uuid: str,
datasetio_api: DatasetIO,
datasets_api: Datasets,
):
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.job_uuid = job_uuid
def validate_dataset_format(self, rows: list[dict]) -> None:
"""Validate that the dataset has the required fields for DPO training."""
required_fields = ["prompt", "chosen", "rejected"]
if not rows:
logger.warning("Dataset is empty")
raise ValueError("Dataset is empty")
for i, row in enumerate(rows):
if not isinstance(row, dict):
logger.warning(f"Row {i} is not a dictionary")
raise ValueError(f"Row {i} is not a dictionary")
for field in required_fields:
if field not in row:
logger.warning(f"Row {i} missing required DPO field: {field}")
raise ValueError(f"Row {i} missing required DPO field: {field}")
# Handle both string and list formats
if field == "prompt":
# Prompt should be a string
if not isinstance(row[field], str):
logger.warning(f"Row {i} field '{field}' is not a string")
raise ValueError(f"Row {i} field '{field}' is not a string")
if not row[field].strip():
logger.warning(f"Row {i} field '{field}' is empty")
raise ValueError(f"Row {i} field '{field}' is empty")
else:
# chosen/rejected can be either strings or lists of messages
if isinstance(row[field], str):
if not row[field].strip():
logger.warning(f"Row {i} field '{field}' is empty")
raise ValueError(f"Row {i} field '{field}' is empty")
elif isinstance(row[field], list):
if not row[field]:
logger.warning(f"Row {i} field '{field}' is empty list")
raise ValueError(f"Row {i} field '{field}' is empty list")
else:
logger.warning(f"Row {i} field '{field}' is neither string nor list")
raise ValueError(f"Row {i} field '{field}' is neither string nor list")
logger.info(f"DPO dataset validation passed: {len(rows)} preference examples")
def _process_dpo_format(self, row: dict) -> tuple[str | None, str | None, str | None]:
"""Process a row in DPO format, handling both string and conversation list formats."""
if all(field in row for field in ["prompt", "chosen", "rejected"]):
prompt = row["prompt"]
# Handle chosen field - convert list to string if needed
if isinstance(row["chosen"], list):
# For conversation format, concatenate messages
chosen = "\n".join(
[msg.get("content", "") if isinstance(msg, dict) else str(msg) for msg in row["chosen"]]
)
else:
chosen = row["chosen"]
# Handle rejected field - convert list to string if needed
if isinstance(row["rejected"], list):
# For conversation format, concatenate messages
rejected = "\n".join(
[msg.get("content", "") if isinstance(msg, dict) else str(msg) for msg in row["rejected"]]
)
else:
rejected = row["rejected"]
return prompt, chosen, rejected
return None, None, None
def _format_text_for_dpo(self, prompt: str, response: str, provider_config: HuggingFacePostTrainingConfig) -> str:
"""Format prompt and response text based on model requirements."""
if hasattr(provider_config, "chat_template") and provider_config.chat_template:
# Use the chat template, supporting both {prompt}/{response} and {input}/{output}
template = provider_config.chat_template
# Try prompt/response first (DPO style)
if "{prompt}" in template and "{response}" in template:
return template.format(prompt=prompt, response=response)
# Fall back to input/output (SFT style)
elif "{input}" in template and "{output}" in template:
return template.format(input=prompt, output=response)
else:
# If template doesn't have expected placeholders, use default
return f"{prompt}\n{response}"
return f"{prompt}\n{response}"
def _create_dataset(
self, rows: list[dict], config: TrainingConfig, provider_config: HuggingFacePostTrainingConfig
) -> Dataset:
"""Create and preprocess the dataset for DPO."""
dpo_examples = []
for row in rows:
prompt, chosen, rejected = self._process_dpo_format(row)
if prompt and chosen and rejected:
# Format the texts
chosen_formatted = self._format_text_for_dpo(prompt, chosen, provider_config)
rejected_formatted = self._format_text_for_dpo(prompt, rejected, provider_config)
dpo_examples.append(
{
"prompt": prompt,
"chosen": chosen_formatted,
"rejected": rejected_formatted,
}
)
if not dpo_examples:
raise ValueError("No valid preference examples found in dataset")
logger.info(f"Created DPO dataset with {len(dpo_examples)} preference pairs")
return Dataset.from_list(dpo_examples)
def _preprocess_dataset(
self, ds: Dataset, tokenizer: AutoTokenizer, provider_config: HuggingFacePostTrainingConfig
) -> Dataset:
"""Preprocess the dataset with tokenizer for DPO."""
# DPOTrainer expects raw text, so we don't tokenize here
# Just return the dataset as is
return ds
def _run_training_sync(
self,
model: str,
provider_config: dict[str, Any],
dpo_config: dict[str, Any],
config: dict[str, Any],
output_dir_path: Path | None,
) -> None:
"""Synchronous wrapper for running DPO training process."""
import asyncio
logger.info("Starting DPO training process with async wrapper")
asyncio.run(
self._run_training(
model=model,
provider_config=provider_config,
dpo_config=dpo_config,
config=config,
output_dir_path=output_dir_path,
)
)
async def load_dataset(
self,
model: str,
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
) -> tuple[Dataset, Dataset, AutoTokenizer]:
"""Load and prepare the dataset for DPO training."""
# Validate data config
if not config.data_config:
raise ValueError("DataConfig is required for DPO training")
# Load dataset
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
self.validate_dataset_format(rows)
logger.info(f"Loaded {len(rows)} rows from dataset")
# Initialize tokenizer
logger.info(f"Initializing tokenizer for model: {model}")
try:
tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config)
# Set pad token to eos token if not present
if not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.eos_token
# Set padding side to left for DPO
tokenizer.padding_side = "left"
# Set truncation side to right to keep the beginning of the sequence
tokenizer.truncation_side = "right"
# Set model max length to match provider config
tokenizer.model_max_length = provider_config.max_seq_length
logger.info("Tokenizer initialized successfully for DPO")
except Exception as e:
raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e
# Create and preprocess dataset
logger.info("Creating and preprocessing dataset for DPO")
try:
ds = self._create_dataset(rows, config, provider_config)
ds = self._preprocess_dataset(ds, tokenizer, provider_config)
logger.info(f"Dataset created with {len(ds)} examples")
except Exception as e:
raise ValueError(f"Failed to create dataset: {str(e)}") from e
# Split dataset
train_dataset, eval_dataset = split_dataset(ds)
return train_dataset, eval_dataset, tokenizer
def setup_training_args(
self,
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
dpo_config: DPOAlignmentConfig,
device: torch.device,
output_dir_path: Path | None,
steps_per_epoch: int,
) -> DPOConfig:
"""Setup DPO training arguments."""
logger.info("Configuring DPO training arguments")
lr = 5e-7 # Lower learning rate for DPO
if config.optimizer_config:
lr = config.optimizer_config.lr
logger.info(f"Using custom learning rate: {lr}")
# Validate data config
if not config.data_config:
raise ValueError("DataConfig is required for training")
data_config = config.data_config
# Calculate steps and get save strategy
step_info = calculate_training_steps(steps_per_epoch, config)
save_strategy, eval_strategy = get_save_strategy(output_dir_path)
logger.info("DPO training configuration:")
logger.info(f"- DPO beta: {dpo_config.beta}")
logger.info(f"- DPO loss type: {provider_config.dpo_loss_type}")
# Calculate max prompt length as half of max sequence length
max_prompt_length = provider_config.max_seq_length // 2
return DPOConfig(
max_steps=step_info["max_steps"],
output_dir=str(output_dir_path) if output_dir_path is not None else None,
num_train_epochs=config.n_epochs,
per_device_train_batch_size=data_config.batch_size,
fp16=device.type == "cuda",
bf16=False, # Causes CPU issues.
eval_strategy=eval_strategy,
use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False,
save_strategy=save_strategy,
report_to="none",
max_length=provider_config.max_seq_length,
max_prompt_length=max_prompt_length,
gradient_accumulation_steps=config.gradient_accumulation_steps,
gradient_checkpointing=provider_config.gradient_checkpointing,
learning_rate=lr,
warmup_ratio=provider_config.warmup_ratio,
weight_decay=provider_config.weight_decay,
remove_unused_columns=False,
dataloader_pin_memory=provider_config.dataloader_pin_memory,
dataloader_num_workers=provider_config.dataloader_num_workers,
load_best_model_at_end=True if output_dir_path else False,
metric_for_best_model="eval_loss",
greater_is_better=False,
logging_steps=step_info["logging_steps"],
save_total_limit=provider_config.save_total_limit,
# DPO specific parameters
beta=dpo_config.beta,
loss_type=provider_config.dpo_loss_type,
)
def save_model(
self,
trainer: DPOTrainer,
output_dir_path: Path,
) -> None:
"""Save the trained DPO model."""
logger.info("Saving final DPO model")
save_path = output_dir_path / "dpo_model"
logger.info(f"Saving model to {save_path}")
# Save model and tokenizer
trainer.save_model(str(save_path))
async def _run_training(
self,
model: str,
provider_config: dict[str, Any],
dpo_config: dict[str, Any],
config: dict[str, Any],
output_dir_path: Path | None,
) -> None:
"""Run the DPO training process with signal handling."""
# Setup environment variables
setup_environment()
# Setup signal handlers
setup_signal_handlers()
# Convert config dicts back to objects
logger.info("Initializing configuration objects")
provider_config_obj = HuggingFacePostTrainingConfig(**provider_config)
config_obj = TrainingConfig(**config)
dpo_config_obj = DPOAlignmentConfig(**dpo_config)
# Initialize and validate device
device = setup_torch_device(provider_config_obj.device)
logger.info(f"Using device '{device}'")
# Load dataset and tokenizer
train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)
# Calculate steps per epoch
if not config_obj.data_config:
raise ValueError("DataConfig is required for training")
steps_per_epoch = len(train_dataset) // config_obj.data_config.batch_size
# Setup training arguments
training_args = self.setup_training_args(
config_obj,
provider_config_obj,
dpo_config_obj,
device,
output_dir_path,
steps_per_epoch,
)
# Load model and reference model
model_obj = load_model(model, device, provider_config_obj)
ref_model = None
if provider_config_obj.use_reference_model:
logger.info("Loading separate reference model for DPO")
ref_model = load_model(model, device, provider_config_obj)
else:
logger.info("Using shared reference model for DPO")
# Initialize DPO trainer
logger.info("Initializing DPOTrainer")
trainer = DPOTrainer(
model=model_obj,
ref_model=ref_model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
)
try:
# Train
logger.info("Starting DPO training")
trainer.train()
logger.info("DPO training completed successfully")
# Save final model if output directory is provided
if output_dir_path:
logger.info(f"Saving model to output directory: {output_dir_path}")
self.save_model(trainer, output_dir_path)
logger.info("Model save completed")
finally:
# Clean up resources
logger.info("Cleaning up resources")
if hasattr(trainer, "model"):
evacuate_model_from_device(trainer.model, device.type)
if ref_model:
evacuate_model_from_device(ref_model, device.type)
del trainer
del ref_model
gc.collect()
logger.info("Cleanup completed")
logger.info("DPO training process finishing successfully")
async def train(
self,
model: str,
output_dir: str | None,
job_uuid: str,
dpo_config: DPOAlignmentConfig,
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
) -> tuple[dict[str, Any], list[Checkpoint] | None]:
"""Train a model using HuggingFace's DPOTrainer"""
# Initialize and validate device
device = setup_torch_device(provider_config.device)
logger.info(f"Using device '{device}'")
output_dir_path = None
if output_dir:
output_dir_path = Path(output_dir)
# Track memory stats
memory_stats = {
"initial": get_memory_stats(device),
"after_training": None,
"final": None,
}
# Validate data config
if not config.data_config:
raise ValueError("DataConfig is required for training")
# Train in a separate process
logger.info("Starting DPO training in separate process")
try:
# Setup multiprocessing for device
if device.type in ["cuda", "mps"]:
multiprocessing.set_start_method("spawn", force=True)
process = multiprocessing.Process(
target=self._run_training_sync,
kwargs={
"model": model,
"provider_config": provider_config.model_dump(),
"dpo_config": dpo_config.model_dump(),
"config": config.model_dump(),
"output_dir_path": output_dir_path,
},
)
process.start()
# Monitor the process
while process.is_alive():
process.join(timeout=1) # Check every second
if not process.is_alive():
break
# Get the return code
if process.exitcode != 0:
raise RuntimeError(f"DPO training failed with exit code {process.exitcode}")
memory_stats["after_training"] = get_memory_stats(device)
checkpoints = []
if output_dir_path:
checkpoints = create_checkpoints(output_dir_path, job_uuid, model, config, "dpo_model")
return memory_stats, checkpoints if checkpoints else None
finally:
memory_stats["final"] = get_memory_stats(device)
gc.collect()

View file

@ -0,0 +1,269 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import signal
import sys
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
import psutil
import torch
from datasets import Dataset
from transformers import AutoConfig, AutoModelForCausalLM
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.post_training import Checkpoint, TrainingConfig
from llama_stack.log import get_logger
from .config import HuggingFacePostTrainingConfig
logger = get_logger(name=__name__, category="post_training")
def setup_environment():
"""Setup common environment variables for training."""
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
os.environ["MKL_NUM_THREADS"] = "1"
def bytes_to_gb(to_convert: int) -> str:
"""Converts memory stats to GB and formats to 2 decimal places.
Args:
to_convert: Memory value in bytes
Returns:
str: Memory value in GB formatted to 2 decimal places
"""
return f"{(to_convert / (1024**3)):.2f}"
def get_memory_stats(device: torch.device) -> dict[str, Any]:
"""Get memory statistics for the given device."""
stats = {
"system_memory": {
"total": bytes_to_gb(psutil.virtual_memory().total),
"available": bytes_to_gb(psutil.virtual_memory().available),
"used": bytes_to_gb(psutil.virtual_memory().used),
"percent": psutil.virtual_memory().percent,
}
}
if device.type == "cuda":
stats["device_memory"] = {
"allocated": bytes_to_gb(torch.cuda.memory_allocated(device)),
"reserved": bytes_to_gb(torch.cuda.memory_reserved(device)),
"max_allocated": bytes_to_gb(torch.cuda.max_memory_allocated(device)),
}
elif device.type == "mps":
# MPS doesn't provide direct memory stats, but we can track system memory
stats["device_memory"] = {
"note": "MPS memory stats not directly available",
"system_memory_used": bytes_to_gb(psutil.virtual_memory().used),
}
elif device.type == "cpu":
# For CPU, we track process memory usage
process = psutil.Process()
stats["device_memory"] = {
"process_rss": bytes_to_gb(process.memory_info().rss),
"process_vms": bytes_to_gb(process.memory_info().vms),
"process_percent": process.memory_percent(),
}
return stats
def setup_torch_device(device_str: str) -> torch.device:
"""Initialize and validate a PyTorch device.
This function handles device initialization and validation for different device types:
- CUDA: Validates CUDA availability and handles device selection
- MPS: Validates MPS availability for Apple Silicon
- CPU: Basic validation
- HPU: Raises error as it's not supported
Args:
device_str: String specifying the device ('cuda', 'cpu', 'mps')
Returns:
torch.device: The initialized and validated device
Raises:
RuntimeError: If device initialization fails or device is not supported
"""
try:
device = torch.device(device_str)
except RuntimeError as e:
raise RuntimeError(f"Error getting Torch Device {str(e)}") from e
# Validate device capabilities
if device.type == "cuda":
if not torch.cuda.is_available():
raise RuntimeError(
f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device."
)
if device.index is None:
device = torch.device(device.type, torch.cuda.current_device())
elif device.type == "mps":
if not torch.backends.mps.is_available():
raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.")
elif device.type == "hpu":
raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.")
return device
async def load_rows_from_dataset(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
"""Load dataset from llama stack dataset provider"""
try:
all_rows = await datasetio_api.iterrows(
dataset_id=dataset_id,
limit=-1,
)
if not isinstance(all_rows.data, list):
raise RuntimeError("Expected dataset data to be a list")
return all_rows.data
except Exception as e:
raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
def load_model(
model: str,
device: torch.device,
provider_config: HuggingFacePostTrainingConfig,
) -> AutoModelForCausalLM:
"""Load and initialize the model for training.
Args:
model: The model identifier to load
device: The device to load the model onto
provider_config: Provider-specific configuration
Returns:
The loaded and initialized model
Raises:
RuntimeError: If model loading fails
"""
logger.info("Loading the base model")
try:
model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
model_obj = AutoModelForCausalLM.from_pretrained(
model,
torch_dtype="auto" if device.type != "cpu" else "float32",
quantization_config=None,
config=model_config,
**provider_config.model_specific_config,
)
# Always move model to specified device
model_obj = model_obj.to(device)
logger.info(f"Model loaded and moved to device: {model_obj.device}")
return model_obj
except Exception as e:
raise RuntimeError(f"Failed to load model: {str(e)}") from e
def split_dataset(ds: Dataset) -> tuple[Dataset, Dataset]:
"""Split dataset into train and validation sets.
Args:
ds: Dataset to split
Returns:
tuple: (train_dataset, eval_dataset)
"""
logger.info("Splitting dataset into train and validation sets")
train_val_split = ds.train_test_split(test_size=0.1, seed=42)
train_dataset = train_val_split["train"]
eval_dataset = train_val_split["test"]
logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
return train_dataset, eval_dataset
def setup_signal_handlers():
"""Setup signal handlers for graceful shutdown."""
def signal_handler(signum, frame):
logger.info(f"Received signal {signum}, initiating graceful shutdown")
sys.exit(0)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
def calculate_training_steps(steps_per_epoch: int, config: TrainingConfig) -> dict[str, int]:
"""Calculate training steps and logging configuration.
Args:
steps_per_epoch: Number of training steps per epoch
config: Training configuration
Returns:
dict: Dictionary with calculated step values
"""
total_steps = steps_per_epoch * config.n_epochs
max_steps = min(config.max_steps_per_epoch, total_steps)
logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch
logger.info("Training configuration:")
logger.info(f"- Steps per epoch: {steps_per_epoch}")
logger.info(f"- Total steps: {total_steps}")
logger.info(f"- Max steps: {max_steps}")
logger.info(f"- Logging steps: {logging_steps}")
return {"total_steps": total_steps, "max_steps": max_steps, "logging_steps": logging_steps}
def get_save_strategy(output_dir_path: Path | None) -> tuple[str, str]:
"""Get save and evaluation strategy based on output directory.
Args:
output_dir_path: Optional path to save the model
Returns:
tuple: (save_strategy, eval_strategy)
"""
if output_dir_path:
logger.info(f"Will save checkpoints to {output_dir_path}")
return "epoch", "epoch"
return "no", "no"
def create_checkpoints(
output_dir_path: Path, job_uuid: str, model: str, config: TrainingConfig, final_model_name: str
) -> list[Checkpoint]:
"""Create checkpoint objects from training output.
Args:
output_dir_path: Path to the training output directory
job_uuid: Unique identifier for the training job
model: Model identifier
config: Training configuration
final_model_name: Name of the final model directory ("merged_model" for SFT, "dpo_model" for DPO)
Returns:
List of Checkpoint objects
"""
checkpoints = []
# Add checkpoint directories
checkpoint_dirs = sorted(
[d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()],
key=lambda x: int(x.name.split("-")[1]),
)
for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1):
created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC)
checkpoint = Checkpoint(
identifier=checkpoint_dir.name,
created_at=created_time,
epoch=epoch_number,
post_training_job_id=job_uuid,
path=str(checkpoint_dir),
)
checkpoints.append(checkpoint)
# Add final model
final_model_path = output_dir_path / final_model_name
if final_model_path.exists():
training_type = "sft" if final_model_name == "merged_model" else "dpo"
checkpoint = Checkpoint(
identifier=f"{model}-{training_type}-{config.n_epochs}",
created_at=datetime.now(UTC),
epoch=config.n_epochs,
post_training_job_id=job_uuid,
path=str(final_model_path),
)
checkpoints.append(checkpoint)
return checkpoints

View file

@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.core.datatypes import Api
from .config import TorchtunePostTrainingConfig
# post_training api and the torchtune provider is still experimental and under heavy development
async def get_provider_impl(
config: TorchtunePostTrainingConfig,
deps: dict[Api, Any],
):
from .post_training import TorchtunePostTrainingImpl
impl = TorchtunePostTrainingImpl(
config,
deps[Api.datasetio],
deps[Api.datasets],
)
return impl

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,240 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
import shutil
from pathlib import Path
from typing import Any
import torch
from safetensors.torch import save_file
from torchtune import training
from torchtune.models import convert_weights
from torchtune.training.checkpointing._utils import (
ADAPTER_CONFIG_FNAME,
ADAPTER_MODEL_FNAME,
REPO_ID_FNAME,
SUFFIXES_TO_NOT_COPY,
ModelType,
copy_files,
safe_torch_load,
)
from torchtune.utils._logging import get_logger
logger = get_logger("DEBUG")
class TorchtuneCheckpointer:
def __init__(
self,
model_id: str,
training_algorithm: str,
checkpoint_dir: str,
checkpoint_files: list[str],
output_dir: str,
model_type: str,
):
# Fail fast if ``checkpoint_files`` is invalid
# TODO: support loading more than one file
if len(checkpoint_files) != 1:
raise ValueError(
"Currently we only support reading from a single torchtune checkpoint file. "
f"Got {len(checkpoint_files)} files instead."
)
self._checkpoint_file = checkpoint_files[0]
self._model_id = model_id
self._training_algorithm = training_algorithm
self._checkpoint_dir = Path(checkpoint_dir)
self._model_type = ModelType[model_type]
self._output_dir = output_dir
# get ckpt paths
self._checkpoint_path = Path.joinpath(self._checkpoint_dir, self._checkpoint_file)
def load_checkpoint(self) -> dict[str, Any]:
"""
Load Meta checkpoint from file. Currently only loading from a single file is supported.
"""
state_dict: dict[str, Any] = {}
model_state_dict = safe_torch_load(self._checkpoint_path)
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
llama3_vision_meta_to_tune,
)
state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(model_state_dict)
else:
state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(model_state_dict)
# llama3_2 has tied weights, so we need to remove the output.weight key
if self._model_type == ModelType.LLAMA3_2:
logger.info(
"Identified model_type = Llama3_2. Ignoring output.weight in"
" checkpoint in favor of the tok_embedding.weight"
" tied weights."
)
state_dict[training.MODEL_KEY].pop("output.weight")
return state_dict
def save_checkpoint(
self,
state_dict: dict[str, Any],
epoch: int,
adapter_only: bool = False,
checkpoint_format: str | None = None,
) -> str:
model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"
if checkpoint_format == "meta" or checkpoint_format is None:
self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only)
elif checkpoint_format == "huggingface":
# Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now
self._save_hf_format_checkpoint(model_file_path, state_dict)
else:
raise ValueError(f"Unsupported checkpoint format: {format}")
return str(model_file_path)
def _save_meta_format_checkpoint(
self,
model_file_path: Path,
state_dict: dict[str, Any],
adapter_only: bool = False,
) -> None:
model_file_path.mkdir(parents=True, exist_ok=True)
# copy the related files for inference
source_path = Path.joinpath(self._checkpoint_dir, "params.json")
if source_path.exists():
shutil.copy(
source_path,
Path.joinpath(model_file_path, "params.json"),
)
source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
if source_path.exists():
shutil.copy(
source_path,
Path.joinpath(model_file_path, "tokenizer.model"),
)
source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
if source_path.exists():
shutil.copy(
source_path,
Path.joinpath(model_file_path, "orig_params.json"),
)
if not adapter_only:
model_state_dict = state_dict[training.MODEL_KEY]
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
llama3_vision_tune_to_meta,
)
state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(model_state_dict)
else:
# llama3_2 has tied weights, so we need to add the output.weight key
if self._model_type == ModelType.LLAMA3_2 and "output.weight" not in model_state_dict:
model_state_dict["output.weight"] = model_state_dict["tok_embeddings.weight"]
state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(model_state_dict)
model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")
torch.save(state_dict[training.MODEL_KEY], model_file_name)
logger.info(
"Model checkpoint of size "
f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
f"saved to {model_file_name}"
)
if training.ADAPTER_KEY in state_dict:
adapter_file_path = model_file_path / "adapter"
adapter_file_path.mkdir(parents=True, exist_ok=True)
adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
logger.info(
"Adapter checkpoint of size "
f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
f"saved to {adapter_file_name}"
)
elif adapter_only:
raise ValueError(
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
)
def _save_hf_format_checkpoint(
self,
model_file_path: Path,
state_dict: dict[str, Any],
) -> None:
# the config.json file contains model params needed for state dict conversion
config = json.loads(Path.joinpath(self._checkpoint_dir.parent, "config.json").read_text())
# repo_id is necessary for when saving an adapter config, so its compatible with HF.
# This json file is produced and saved in the download step.
# contents are {"repo_id": "some_model/some_model_version"}
repo_id_path = Path.joinpath(self._checkpoint_dir.parent, REPO_ID_FNAME).with_suffix(".json")
self.repo_id = None
if repo_id_path.exists():
with open(repo_id_path) as json_file:
data = json.load(json_file)
self.repo_id = data.get("repo_id")
if training.ADAPTER_KEY in state_dict:
# TODO: saving it "as is" is a requirement because, if we only save with
# convert_weights.tune_to_peft_adapter_weights, we do NOT have a fn
# convert_weights.peft_to_tune. The .pt format is not needed, but
# it is an easy way to distinguish the adapters. Ideally we should save only one.
output_path = Path.joinpath(model_file_path, ADAPTER_MODEL_FNAME).with_suffix(".pt")
output_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(state_dict[training.ADAPTER_KEY], output_path)
logger.info(
f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
)
state_dict[training.ADAPTER_KEY] = convert_weights.tune_to_peft_adapter_weights(
state_dict[training.ADAPTER_KEY],
num_heads=config["num_attention_heads"],
num_kv_heads=config["num_key_value_heads"],
dim=config["hidden_size"],
head_dim=config.get("head_dim", None),
)
output_path = Path.joinpath(model_file_path, "adapter", ADAPTER_MODEL_FNAME)
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path = output_path.with_suffix(".safetensors")
save_file(
state_dict[training.ADAPTER_KEY],
output_path,
metadata={"format": "pt"},
)
logger.info(
f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
)
else:
raise ValueError(
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
)
if training.ADAPTER_CONFIG in state_dict:
state_dict[training.ADAPTER_CONFIG] = convert_weights.tune_to_peft_adapter_config(
adapter_config=state_dict[training.ADAPTER_CONFIG],
base_model_name_or_path=self.repo_id,
)
output_path = Path.joinpath(model_file_path, "adapter", ADAPTER_CONFIG_FNAME).with_suffix(".json")
with open(output_path, "w") as f:
json.dump(state_dict[training.ADAPTER_CONFIG], f)
logger.info(
f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
)
# Save all files in ckpt_dir, except model weights and mapping, to output_dir/epoch_{epoch}
# So its easy to run inference with the model using this epoch's checkpoint
copy_files(
self._checkpoint_dir.parent,
model_file_path,
ignore_suffixes=SUFFIXES_TO_NOT_COPY,
)

View file

@ -0,0 +1,99 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Callable
import torch
from pydantic import BaseModel
from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
from torchtune.models.llama3 import llama3_tokenizer
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.models.llama3_1 import lora_llama3_1_8b
from torchtune.models.llama3_2 import lora_llama3_2_3b
from torchtune.modules.transforms import Transform
from llama_stack.apis.post_training import DatasetFormat
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import Model
BuildLoraModelCallable = Callable[..., torch.nn.Module]
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
class ModelConfig(BaseModel):
model_definition: BuildLoraModelCallable
tokenizer_type: BuildTokenizerCallable
checkpoint_type: str
MODEL_CONFIGS: dict[str, ModelConfig] = {
"Llama3.2-3B-Instruct": ModelConfig(
model_definition=lora_llama3_2_3b,
tokenizer_type=llama3_tokenizer,
checkpoint_type="LLAMA3_2",
),
"Llama3.1-8B-Instruct": ModelConfig(
model_definition=lora_llama3_1_8b,
tokenizer_type=llama3_tokenizer,
checkpoint_type="LLAMA3",
),
}
DATA_FORMATS: dict[str, Transform] = {
"instruct": InputOutputToMessages,
"dialog": ShareGPTToMessages,
}
def _validate_model_id(model_id: str) -> Model:
model = resolve_model(model_id)
if model is None or model.core_model_id.value not in MODEL_CONFIGS:
raise ValueError(f"Model {model_id} is not supported.")
return model
async def get_model_definition(
model_id: str,
) -> BuildLoraModelCallable:
model = _validate_model_id(model_id)
model_config = MODEL_CONFIGS[model.core_model_id.value]
if not hasattr(model_config, "model_definition"):
raise ValueError(f"Model {model_id} does not have model definition.")
return model_config.model_definition
async def get_tokenizer_type(
model_id: str,
) -> BuildTokenizerCallable:
model = _validate_model_id(model_id)
model_config = MODEL_CONFIGS[model.core_model_id.value]
if not hasattr(model_config, "tokenizer_type"):
raise ValueError(f"Model {model_id} does not have tokenizer_type.")
return model_config.tokenizer_type
async def get_checkpointer_model_type(
model_id: str,
) -> str:
"""
checkpointer model type is used in checkpointer for some special treatment on some specific model types
For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
"""
model = _validate_model_id(model_id)
model_config = MODEL_CONFIGS[model.core_model_id.value]
if not hasattr(model_config, "checkpoint_type"):
raise ValueError(f"Model {model_id} does not have checkpoint_type.")
return model_config.checkpoint_type
async def get_data_transform(data_format: DatasetFormat) -> Transform:
return DATA_FORMATS[data_format.value]

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Literal
from pydantic import BaseModel
class TorchtunePostTrainingConfig(BaseModel):
torch_seed: int | None = None
checkpoint_format: Literal["meta", "huggingface"] | None = "meta"
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"checkpoint_format": "meta",
}

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import json
from collections.abc import Mapping
from typing import Any
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
def llama_stack_instruct_to_torchtune_instruct(
sample: Mapping[str, Any],
) -> Mapping[str, Any]:
assert ColumnName.chat_completion_input.value in sample and ColumnName.expected_answer.value in sample, (
"Invalid input row"
)
input_messages = json.loads(sample[ColumnName.chat_completion_input.value])
assert len(input_messages) == 1, "llama stack intruct dataset format only supports 1 user message"
input_message = input_messages[0]
assert "content" in input_message, "content not found in input message"
input = input_message["content"]
output = sample[ColumnName.expected_answer.value]
return {
"input": input,
"output": output,
}
def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str, Any]:
assert ColumnName.dialog.value in sample, "Invalid input row"
role_map = {"user": "human", "assistant": "gpt"}
dialog = json.loads(sample[ColumnName.dialog.value])
assert len(dialog) > 1, "dialog must have at least 2 messagse"
roles = []
conversations = []
for message in dialog:
assert "role" in message and "content" in message, "role and content must in message"
roles.append(message["role"])
conversations.append({"from": role_map[message["role"]], "value": message["content"]})
assert roles[0] == "user", "first message must be from user"
assert "assistant" in roles, "at least 1 message should be from assistant"
return {"conversations": conversations}

View file

@ -0,0 +1,78 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from collections.abc import Mapping
from typing import Any
import numpy as np
from torch.utils.data import Dataset
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.data._messages import validate_messages
from torchtune.modules.transforms import Transform
from llama_stack.providers.inline.post_training.torchtune.datasets.format_adapter import (
llama_stack_chat_to_torchtune_chat,
llama_stack_instruct_to_torchtune_instruct,
)
class SFTDataset(Dataset):
def __init__(
self,
rows: list[dict[str, Any]],
message_transform: Transform,
model_transform: Transform,
dataset_type: str,
) -> None:
self._rows = rows
self._message_transform = message_transform
self._model_transform = model_transform
self._dataset_type = dataset_type
def __len__(self):
return len(self._rows)
def __getitem__(self, index: int) -> dict[str, Any]:
sample = self._rows[index]
return self._prepare_sample(sample)
def _prepare_sample(self, sample: Mapping[str, Any]) -> dict[str, Any]:
if self._dataset_type == "instruct":
sample = llama_stack_instruct_to_torchtune_instruct(sample)
elif self._dataset_type == "dialog":
sample = llama_stack_chat_to_torchtune_chat(sample)
else:
raise ValueError(f"Invalid dataset type: {self._dataset_type}")
transformed_sample = self._message_transform(sample)
if "messages" in transformed_sample:
validate_messages(transformed_sample["messages"])
tokenized_dict: dict[str, Any] = self._model_transform(transformed_sample)
if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
keys_str = ", ".join(tokenized_dict.keys())
error_message = (
f"model_transform returned the following keys: {keys_str}. Must return 'tokens' and 'mask' as keys."
)
raise ValueError(error_message)
# Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens
tokenized_dict["labels"] = list(
np.where(
tokenized_dict["mask"],
CROSS_ENTROPY_IGNORE_IDX,
tokenized_dict["tokens"],
)
)
assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"])
return tokenized_dict

View file

@ -0,0 +1,178 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
AlgorithmConfig,
Checkpoint,
DPOAlignmentConfig,
JobStatus,
ListPostTrainingJobsResponse,
LoraFinetuningConfig,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
TrainingConfig,
)
from llama_stack.providers.inline.post_training.torchtune.config import (
TorchtunePostTrainingConfig,
)
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
class TrainingArtifactType(Enum):
CHECKPOINT = "checkpoint"
RESOURCES_STATS = "resources_stats"
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
class TorchtunePostTrainingImpl:
def __init__(
self,
config: TorchtunePostTrainingConfig,
datasetio_api: DatasetIO,
datasets: Datasets,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets
self._scheduler = Scheduler()
async def shutdown(self) -> None:
await self._scheduler.shutdown()
@staticmethod
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.CHECKPOINT.value,
name=checkpoint.identifier,
uri=checkpoint.path,
metadata=dict(checkpoint),
)
@staticmethod
def _resources_stats_to_artifact(resources_stats: dict[str, Any]) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.RESOURCES_STATS.value,
name=TrainingArtifactType.RESOURCES_STATS.value,
metadata=resources_stats,
)
async def supervised_fine_tune(
self,
job_uuid: str,
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
model: str,
checkpoint_dir: str | None,
algorithm_config: AlgorithmConfig | None,
) -> PostTrainingJob:
if isinstance(algorithm_config, LoraFinetuningConfig):
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
LoraFinetuningSingleDevice,
)
on_log_message_cb("Starting Lora finetuning")
recipe = LoraFinetuningSingleDevice(
self.config,
job_uuid,
training_config,
hyperparam_search_config,
logger_config,
model,
checkpoint_dir,
algorithm_config,
self.datasetio_api,
self.datasets_api,
)
await recipe.setup()
resources_allocated, checkpoints = await recipe.train()
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
for checkpoint in checkpoints:
artifact = self._checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
on_status_change_cb(SchedulerJobStatus.completed)
on_log_message_cb("Lora finetuning completed")
else:
raise NotImplementedError()
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
return PostTrainingJob(job_uuid=job_uuid)
async def preference_optimize(
self,
job_uuid: str,
finetuned_model: str,
algorithm_config: DPOAlignmentConfig,
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
) -> PostTrainingJob:
raise NotImplementedError()
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
)
@staticmethod
def _get_artifacts_metadata_by_type(job, artifact_type):
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
@classmethod
def _get_checkpoints(cls, job):
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
@classmethod
def _get_resources_allocated(cls, job):
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
return data[0] if data else None
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
job = self._scheduler.get_job(job_uuid)
match job.status:
# TODO: Add support for other statuses to API
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
status = JobStatus.scheduled
case SchedulerJobStatus.running:
status = JobStatus.in_progress
case SchedulerJobStatus.completed:
status = JobStatus.completed
case SchedulerJobStatus.failed:
status = JobStatus.failed
case _:
raise NotImplementedError()
return PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=status,
scheduled_at=job.scheduled_at,
started_at=job.started_at,
completed_at=job.completed_at,
checkpoints=self._get_checkpoints(job),
resources_allocated=self._get_resources_allocated(job),
)
async def cancel_training_job(self, job_uuid: str) -> None:
self._scheduler.cancel(job_uuid)
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
job = self._scheduler.get_job(job_uuid)
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,588 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import time
from datetime import UTC, datetime
from functools import partial
from pathlib import Path
from typing import Any
import torch
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training
from torchtune import utils as torchtune_utils
from torchtune.data import padded_collate_sft
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
get_adapter_params,
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
)
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
from torchtune.training.metric_logging import DiskLogger
from tqdm import tqdm
from llama_stack.apis.common.training_types import PostTrainingMetric
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
Checkpoint,
DataConfig,
LoraFinetuningConfig,
OptimizerConfig,
QATFinetuningConfig,
TrainingConfig,
)
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.core.utils.model_utils import model_local_dir
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from llama_stack.providers.inline.post_training.torchtune.common import utils
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
TorchtuneCheckpointer,
)
from llama_stack.providers.inline.post_training.torchtune.config import (
TorchtunePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
log = get_logger(name=__name__, category="post_training")
class LoraFinetuningSingleDevice:
# This recipe doesn't include several training efficiency setting within origin torchtune repo, including
# - compile
# - activation offloading
# Resume from checkpoint hasn't been supported yet
# Validation hasn't been supported yet
# Currently logging only logs limited training metrics to local disk
# will figure out more loggings and how it works with telemetry in future PRs
_checkpointer: TorchtuneCheckpointer
def __init__(
self,
config: TorchtunePostTrainingConfig,
job_uuid: str,
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
model: str,
checkpoint_dir: str | None,
algorithm_config: LoraFinetuningConfig | QATFinetuningConfig | None,
datasetio_api: DatasetIO,
datasets_api: Datasets,
) -> None:
assert isinstance(training_config.data_config, DataConfig), "DataConfig must be initialized"
self.job_uuid = job_uuid
self.training_config = training_config
if not isinstance(algorithm_config, LoraFinetuningConfig):
raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning")
self.algorithm_config = algorithm_config
self._device = torchtune_utils.get_device()
self._dtype = training.get_dtype(training_config.dtype, device=self._device)
self.model_id = model
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))
paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
if not any(p.exists() for p in paths):
checkpoint_dir = checkpoint_dir / "original"
hf_repo = model.huggingface_repo or f"meta-llama/{model.descriptor()}"
assert checkpoint_dir.exists(), (
f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
f"Please download the model using `huggingface-cli download {hf_repo} --local-dir ~/.llama/{model.descriptor()}`"
)
return str(checkpoint_dir)
if checkpoint_dir and checkpoint_dir != "null":
self.checkpoint_dir = checkpoint_dir
else:
model_obj = resolve_model(self.model_id)
if model_obj is None:
raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
self.checkpoint_dir = model_checkpoint_dir(model_obj)
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
self._checkpoint_format = config.checkpoint_format
self.seed = training.set_seed(seed=config.torch_seed)
self.epochs_run = 0
self.total_epochs = training_config.n_epochs
self._data_format = training_config.data_config.data_format
self._shuffle = training_config.data_config.shuffle
self._batch_size = training_config.data_config.batch_size
self._train_on_input = training_config.data_config.train_on_input
# this is important for debugging purpose
self.max_steps_per_epoch = training_config.max_steps_per_epoch
self.global_step = 0
self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
self.max_validation_steps = training_config.max_validation_steps
self._clip_grad_norm = 1.0
self._enable_activation_checkpointing = False
self._enable_activation_offloading = False
if training_config.efficiency_config:
if training_config.efficiency_config.enable_activation_checkpointing:
self._enable_activation_checkpointing = (
training_config.efficiency_config.enable_activation_checkpointing
)
if training_config.efficiency_config.enable_activation_offloading:
self._enable_activation_offloading = training_config.efficiency_config.enable_activation_offloading
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
async def load_checkpoint(self):
def get_checkpoint_files(checkpoint_dir: str) -> list[str]:
try:
# List all files in the given directory
files = os.listdir(checkpoint_dir)
# Filter files that end with .pth
pth_files = [file for file in files if file.endswith(".pth")]
return pth_files
except FileNotFoundError:
return [f"Error: The directory '{checkpoint_dir}' does not exist."]
self._checkpointer = TorchtuneCheckpointer(
model_id=self.model_id,
training_algorithm="sft",
checkpoint_dir=self.checkpoint_dir,
checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
output_dir=self._output_dir,
model_type=await utils.get_checkpointer_model_type(self.model_id),
)
checkpoint_dict = self._checkpointer.load_checkpoint()
return checkpoint_dict
async def setup(self) -> None:
checkpoint_dict = await self.load_checkpoint()
self._model = await self._setup_model(
enable_activation_checkpointing=self._enable_activation_checkpointing,
enable_activation_offloading=self._enable_activation_offloading,
base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
lora_weights_state_dict=None,
)
log.info(f"Model is initialized with precision {self._dtype}.")
self._tokenizer = await self._setup_tokenizer()
log.info("Tokenizer is initialized.")
assert isinstance(self.training_config.optimizer_config, OptimizerConfig), "OptimizerConfig must be initialized"
self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config)
log.info("Optimizer is initialized.")
self._loss_fn = CEWithChunkedOutputLoss()
self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
log.info("Loss is initialized.")
assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
self._training_sampler, self._training_dataloader = await self._setup_data(
dataset_id=self.training_config.data_config.dataset_id,
tokenizer=self._tokenizer,
shuffle=self._shuffle,
batch_size=self._batch_size,
)
if self.training_config.data_config.validation_dataset_id:
_, self._validation_dataloader = await self._setup_data(
dataset_id=self.training_config.data_config.validation_dataset_id,
tokenizer=self._tokenizer,
shuffle=False,
batch_size=self._batch_size,
)
log.info("Dataset and Sampler are initialized.")
# Number of training steps in each epoch depends on the number of batches produced
# by the dataloader and the max_steps_per_epoch param set by the user and is used
# for logging and tracking training state. This should be computed after the dataloader
# has been setup
self._steps_per_epoch = len(self._training_dataloader) // self._gradient_accumulation_steps
if self.max_steps_per_epoch is not None and self.max_steps_per_epoch < self._steps_per_epoch:
self._steps_per_epoch = self.max_steps_per_epoch
self.global_step = self.epochs_run * self._steps_per_epoch
# Learning rate scheduler can only be set up after number of steps
# has been computed
self._lr_scheduler = await self._setup_lr_scheduler(
num_warmup_steps=self.training_config.optimizer_config.num_warmup_steps,
num_training_steps=self.total_epochs * self._steps_per_epoch,
last_epoch=self.global_step - 1,
)
log.info("Learning rate scheduler is initialized.")
# Used to ignore labels for loss computation
self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device)
def _log_memory_stats(self):
# torchtune raises: "Logging memory stats is not supported on CPU devices"; do nothing
if self._device.type == "cpu":
return
memory_stats = training.get_memory_stats(device=self._device)
training.log_memory_stats(memory_stats)
async def _setup_model(
self,
enable_activation_checkpointing: bool,
enable_activation_offloading: bool,
base_model_state_dict: dict[str, Any],
lora_weights_state_dict: dict[str, Any] | None = None,
) -> nn.Module:
self._lora_rank = self.algorithm_config.rank
self._lora_alpha = self.algorithm_config.alpha
self._lora_attn_modules = list(self.algorithm_config.lora_attn_modules)
self._apply_lora_to_mlp = self.algorithm_config.apply_lora_to_mlp
self._apply_lora_to_output = self.algorithm_config.apply_lora_to_output
self._use_dora = self.algorithm_config.use_dora or False
with training.set_default_dtype(self._dtype), self._device:
model_type = await utils.get_model_definition(self.model_id)
model = model_type(
lora_attn_modules=self._lora_attn_modules,
apply_lora_to_mlp=self._apply_lora_to_mlp,
apply_lora_to_output=self._apply_lora_to_output,
lora_rank=self._lora_rank,
lora_alpha=self._lora_alpha,
quantize_base=False,
use_dora=self._use_dora,
)
self.adapter_params = get_adapter_params(model)
self._is_dora = any("magnitude" in k for k in self.adapter_params.keys())
set_trainable_params(model, self.adapter_params)
if enable_activation_checkpointing:
training.set_activation_checkpointing(model, auto_wrap_policy={modules.TransformerSelfAttentionLayer})
base_missing, base_unexpected = model.load_state_dict(base_model_state_dict, strict=False)
# This is for any adapters that need to be initialized after base weights
# have been loaded (e.g. DoRA).
if self._is_dora:
for m in model.modules():
if hasattr(m, "initialize_dora_magnitude"):
m.initialize_dora_magnitude()
if lora_weights_state_dict:
lora_missing, lora_unexpected = model.load_state_dict(lora_weights_state_dict, strict=False)
else:
lora_missing, lora_unexpected = None, None
validate_missing_and_unexpected_for_lora(
lora_attn_modules=self._lora_attn_modules,
apply_lora_to_mlp=self._apply_lora_to_mlp,
apply_lora_to_output=self._apply_lora_to_output,
base_missing=base_missing,
base_unexpected=base_unexpected,
lora_missing=lora_missing,
lora_unexpected=lora_unexpected,
)
# Validate model adapter params were loaded in with the expected dtype
training.validate_expected_param_dtype(self.adapter_params.items(), dtype=self._dtype)
# activation offloading
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading)
self._log_memory_stats()
return model
async def _setup_tokenizer(
self,
) -> Llama3Tokenizer:
tokenizer_path = self.checkpoint_dir + "/tokenizer.model"
tokenizer_type = await utils.get_tokenizer_type(self.model_id)
return tokenizer_type(path=tokenizer_path)
async def _setup_optimizer(self, optimizer_config: OptimizerConfig) -> Optimizer:
optimizer = torch.optim.AdamW(
params=self._model.parameters(),
lr=optimizer_config.lr,
betas=(0.9, 0.95),
eps=1e-8,
weight_decay=0.1,
)
return optimizer
async def _setup_data(
self,
dataset_id: str,
tokenizer: Llama3Tokenizer,
shuffle: bool,
batch_size: int,
) -> tuple[DistributedSampler, DataLoader]:
async def fetch_rows(dataset_id: str):
return await self.datasetio_api.iterrows(
dataset_id=dataset_id,
limit=-1,
)
all_rows = await fetch_rows(dataset_id)
rows = all_rows.data
# TODO (xiyan): validate dataset schema
# dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
data_transform = await utils.get_data_transform(self._data_format)
ds = SFTDataset(
rows,
message_transform=data_transform(train_on_input=self._train_on_input),
model_transform=tokenizer,
dataset_type=self._data_format.value,
)
sampler = DistributedSampler(
ds,
num_replicas=1,
rank=0,
shuffle=shuffle,
seed=0,
)
dataloader = DataLoader(
dataset=ds,
sampler=sampler,
batch_size=batch_size,
# dropping last avoids shape issues with compile + flex attention
drop_last=True,
collate_fn=(
partial(
padded_collate_sft,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
),
)
return sampler, dataloader
async def _setup_lr_scheduler(
self,
num_warmup_steps: int,
num_training_steps: int,
last_epoch: int,
) -> Optimizer:
lr_scheduler = get_cosine_schedule_with_warmup(
self._optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
last_epoch=last_epoch,
)
return lr_scheduler
async def save_checkpoint(self, epoch: int) -> str:
ckpt_dict = {}
adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
# Construct the full state dict with LoRA weights merged into base LLM weights
# Move to CPU to avoid a copy on GPU
state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}
merged_state_dict = get_merged_lora_ckpt(
state_dict,
rank=self._lora_rank,
alpha=self._lora_alpha,
)
ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
adapter_config = {
"r": self._lora_rank,
"lora_alpha": self._lora_alpha,
"target_modules": get_lora_module_names(
self._lora_attn_modules,
self._apply_lora_to_mlp,
self._apply_lora_to_output,
),
"peft_type": "LORA",
}
ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})
return self._checkpointer.save_checkpoint(
ckpt_dict,
epoch=epoch,
checkpoint_format=self._checkpoint_format,
)
async def _loss_step(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
# Shape [b, s], needed for the loss not the model
labels = batch.pop("labels")
# run model
with self.activations_handling_ctx:
logits = self._model(**batch)
# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
# But this way we dont need to slice the logits. We just add an ignore index to labels.
labels = torch.hstack((labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]))
if not isinstance(logits, list):
labels = labels.reshape(-1)
logits = logits.reshape(-1, logits.size(-1))
loss = self._loss_fn(logits, labels)
# free logits otherwise it peaks backward memory
del logits
return loss
async def train(self) -> tuple[dict[str, Any], list[Checkpoint]]:
"""
The core training loop.
"""
assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
# Initialize tokens count and running loss (for grad accumulation)
t0 = time.perf_counter()
running_loss: float = 0.0
num_tokens = 0
# training artifacts
checkpoints = []
memory_stats: dict[str, Any] = {}
# self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs):
# Update the sampler to ensure data is correctly shuffled across epochs
# in case shuffle is True
metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log")
self._training_sampler.set_epoch(curr_epoch)
loss_to_log = 0.0
pbar = tqdm(total=self._steps_per_epoch)
for idx, batch in enumerate(self._training_dataloader):
if (
self.max_steps_per_epoch is not None
and (idx // self._gradient_accumulation_steps) == self.max_steps_per_epoch
):
break
torchtune_utils.batch_to_device(batch, self._device)
# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum()
num_tokens += current_num_tokens
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
current_loss = await self._loss_step(batch) * current_num_tokens
running_loss += current_loss.detach().item()
current_loss.backward()
# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
training.scale_grads(self._model, 1 / num_tokens)
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
max_norm=float(self._clip_grad_norm),
)
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)
self._lr_scheduler.step()
# Update the number of steps when the weights are updated
self.global_step += 1
loss_to_log = running_loss / num_tokens
pbar.update(1)
pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")
time_per_step = time.perf_counter() - t0
log_dict = {
"loss": loss_to_log,
"lr": self._optimizer.param_groups[0]["lr"],
"tokens_per_second_per_gpu": num_tokens / time_per_step,
}
self._log_memory_stats()
if self._clip_grad_norm is not None:
log_dict.update({"grad_norm": grad_norm})
metric_logger.log_dict(
log_dict,
step=self.global_step,
)
# Reset running stats for the next step
running_loss = 0.0
num_tokens = 0
t0 = time.perf_counter()
self.epochs_run += 1
log.info("Starting checkpoint save...")
checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
checkpoint = Checkpoint(
identifier=f"{self.model_id}-sft-{curr_epoch}",
created_at=datetime.now(UTC),
epoch=curr_epoch,
post_training_job_id=self.job_uuid,
path=checkpoint_path,
)
if self.training_config.data_config.validation_dataset_id:
validation_loss, perplexity = await self.validation()
training_metrics = PostTrainingMetric(
epoch=curr_epoch,
train_loss=loss_to_log,
validation_loss=validation_loss,
perplexity=perplexity,
)
checkpoint.training_metrics = training_metrics
checkpoints.append(checkpoint)
# clean up the memory after training finishes
evacuate_model_from_device(self._model, self._device.type)
return (memory_stats, checkpoints)
async def validation(self) -> tuple[float, float]:
total_loss = 0.0
total_tokens = 0
log.info("Starting validation...")
pbar = tqdm(total=len(self._validation_dataloader))
for idx, batch in enumerate(self._validation_dataloader):
if idx == self.max_validation_steps:
break
torchtune_utils.batch_to_device(batch, self._device)
# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum()
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
loss = await self._loss_step(batch) * num_tokens
total_loss += loss
total_tokens += num_tokens
pbar.update(1)
pbar.set_description(f"validation step: {idx}")
mean_loss = total_loss / total_tokens
perplexity = torch.exp(torch.tensor(mean_loss))
return mean_loss, perplexity.item()

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import CodeScannerConfig
async def get_provider_impl(config: CodeScannerConfig, deps: dict[str, Any]):
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
await impl.initialize()
return impl

View file

@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from codeshield.cs import CodeShieldScanResult
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from .config import CodeScannerConfig
log = get_logger(name=__name__, category="safety")
ALLOWED_CODE_SCANNER_MODEL_IDS = [
"code-scanner",
"code-shield",
]
class MetaReferenceCodeScannerSafetyImpl(Safety):
def __init__(self, config: CodeScannerConfig, deps) -> None:
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: Shield) -> None:
if shield.provider_resource_id not in ALLOWED_CODE_SCANNER_MODEL_IDS:
raise ValueError(
f"Unsupported Code Scanner ID: {shield.provider_resource_id}. Allowed IDs: {ALLOWED_CODE_SCANNER_MODEL_IDS}"
)
async def run_shield(
self,
shield_id: str,
messages: list[OpenAIMessageParam],
params: dict[str, Any] = None,
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Shield {shield_id} not found")
from codeshield.cs import CodeShield
text = "\n".join([interleaved_content_as_str(m.content) for m in messages])
log.info(f"Running CodeScannerShield on {text[50:]}")
result = await CodeShield.scan_code(text)
violation = None
if result.is_insecure:
violation = SafetyViolation(
violation_level=(ViolationLevel.ERROR),
user_message="Sorry, I found security concerns in the code.",
metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])},
)
return RunShieldResponse(violation=violation)
def get_moderation_object_results(self, scan_result: "CodeShieldScanResult") -> ModerationObjectResults:
categories = {}
category_scores = {}
category_applied_input_types = {}
flagged = scan_result.is_insecure
user_message = None
metadata = {}
if scan_result.is_insecure:
pattern_ids = [issue.pattern_id for issue in scan_result.issues_found]
categories = dict.fromkeys(pattern_ids, True)
category_scores = dict.fromkeys(pattern_ids, 1.0)
category_applied_input_types = {key: ["text"] for key in pattern_ids}
user_message = f"Security concerns detected in the code. {scan_result.recommended_treatment.name}: {', '.join([issue.description for issue in scan_result.issues_found])}"
metadata = {"violation_type": ",".join([issue.pattern_id for issue in scan_result.issues_found])}
return ModerationObjectResults(
flagged=flagged,
categories=categories,
category_scores=category_scores,
category_applied_input_types=category_applied_input_types,
user_message=user_message,
metadata=metadata,
)
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
if model is None:
raise ValueError("Code scanner moderation requires a model identifier.")
inputs = input if isinstance(input, list) else [input]
results = []
from codeshield.cs import CodeShield
for text_input in inputs:
log.info(f"Running CodeScannerShield moderation on input: {text_input[:100]}...")
try:
scan_result = await CodeShield.scan_code(text_input)
moderation_result = self.get_moderation_object_results(scan_result)
except Exception as e:
log.error(f"CodeShield.scan_code failed: {e}")
# create safe fallback response on scanner failure to avoid blocking legitimate requests
moderation_result = ModerationObjectResults(
flagged=False,
categories={},
category_scores={},
category_applied_input_types={},
user_message=None,
metadata={"scanner_error": str(e)},
)
results.append(moderation_result)
return ModerationObject(id=str(uuid.uuid4()), model=model, results=results)

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel
class CodeScannerConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {}

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import LlamaGuardConfig
async def get_provider_impl(config: LlamaGuardConfig, deps: dict[str, Any]):
from .llama_guard import LlamaGuardSafetyImpl
assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"
impl = LlamaGuardSafetyImpl(config, deps)
await impl.initialize()
return impl

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel
class LlamaGuardConfig(BaseModel):
excluded_categories: list[str] = []
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"excluded_categories": [],
}

View file

@ -0,0 +1,492 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
import uuid
from string import Template
from typing import Any
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
from llama_stack.apis.inference import (
Inference,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.shields import Shield
from llama_stack.core.datatypes import Api
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import Role
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from .config import LlamaGuardConfig
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
SAFE_RESPONSE = "safe"
CAT_VIOLENT_CRIMES = "Violent Crimes"
CAT_NON_VIOLENT_CRIMES = "Non-Violent Crimes"
CAT_SEX_CRIMES = "Sex Crimes"
CAT_CHILD_EXPLOITATION = "Child Exploitation"
CAT_DEFAMATION = "Defamation"
CAT_SPECIALIZED_ADVICE = "Specialized Advice"
CAT_PRIVACY = "Privacy"
CAT_INTELLECTUAL_PROPERTY = "Intellectual Property"
CAT_INDISCRIMINATE_WEAPONS = "Indiscriminate Weapons"
CAT_HATE = "Hate"
CAT_SELF_HARM = "Self-Harm"
CAT_SEXUAL_CONTENT = "Sexual Content"
CAT_ELECTIONS = "Elections"
CAT_CODE_INTERPRETER_ABUSE = "Code Interpreter Abuse"
SAFETY_CATEGORIES_TO_CODE_MAP = {
CAT_VIOLENT_CRIMES: "S1",
CAT_NON_VIOLENT_CRIMES: "S2",
CAT_SEX_CRIMES: "S3",
CAT_CHILD_EXPLOITATION: "S4",
CAT_DEFAMATION: "S5",
CAT_SPECIALIZED_ADVICE: "S6",
CAT_PRIVACY: "S7",
CAT_INTELLECTUAL_PROPERTY: "S8",
CAT_INDISCRIMINATE_WEAPONS: "S9",
CAT_HATE: "S10",
CAT_SELF_HARM: "S11",
CAT_SEXUAL_CONTENT: "S12",
CAT_ELECTIONS: "S13",
CAT_CODE_INTERPRETER_ABUSE: "S14",
}
SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
DEFAULT_LG_V3_SAFETY_CATEGORIES = [
CAT_VIOLENT_CRIMES,
CAT_NON_VIOLENT_CRIMES,
CAT_SEX_CRIMES,
CAT_CHILD_EXPLOITATION,
CAT_DEFAMATION,
CAT_SPECIALIZED_ADVICE,
CAT_PRIVACY,
CAT_INTELLECTUAL_PROPERTY,
CAT_INDISCRIMINATE_WEAPONS,
CAT_HATE,
CAT_SELF_HARM,
CAT_SEXUAL_CONTENT,
CAT_ELECTIONS,
]
# accept both CoreModelId and huggingface repo id
LLAMA_GUARD_MODEL_IDS = {
CoreModelId.llama_guard_3_8b.value: "meta-llama/Llama-Guard-3-8B",
"meta-llama/Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
CoreModelId.llama_guard_3_1b.value: "meta-llama/Llama-Guard-3-1B",
"meta-llama/Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
CoreModelId.llama_guard_3_11b_vision.value: "meta-llama/Llama-Guard-3-11B-Vision",
"meta-llama/Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
CoreModelId.llama_guard_4_12b.value: "meta-llama/Llama-Guard-4-12B",
"meta-llama/Llama-Guard-4-12B": "meta-llama/Llama-Guard-4-12B",
}
MODEL_TO_SAFETY_CATEGORIES_MAP = {
"meta-llama/Llama-Guard-3-8B": DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE],
"meta-llama/Llama-Guard-3-1B": DEFAULT_LG_V3_SAFETY_CATEGORIES,
"meta-llama/Llama-Guard-3-11B-Vision": DEFAULT_LG_V3_SAFETY_CATEGORIES,
# Llama Guard 4 uses the same categories as Llama Guard 3
# source: https://github.com/meta-llama/PurpleLlama/blob/main/Llama-Guard4/12B/MODEL_CARD.md
"meta-llama/Llama-Guard-4-12B": DEFAULT_LG_V3_SAFETY_CATEGORIES,
}
PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
SAFETY_CATEGORIES = """
<BEGIN UNSAFE CONTENT CATEGORIES>
$categories
<END UNSAFE CONTENT CATEGORIES>
"""
PROMPT_CONVERSATION = """
<BEGIN CONVERSATION>
$conversations
<END CONVERSATION>
"""
PROMPT_INSTRUCTIONS = """
Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories."""
PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}")
logger = get_logger(name=__name__, category="safety")
class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
def __init__(self, config: LlamaGuardConfig, deps) -> None:
self.config = config
self.inference_api = deps[Api.inference]
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: Shield) -> None:
model_id = shield.provider_resource_id
if not model_id:
raise ValueError("Llama Guard shield must have a model id")
async def unregister_shield(self, identifier: str) -> None:
# LlamaGuard doesn't need to do anything special for unregistration
# The routing table handles the removal from the registry
pass
async def run_shield(
self,
shield_id: str,
messages: list[OpenAIMessageParam],
params: dict[str, Any] = None,
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Unknown shield {shield_id}")
messages = messages.copy()
# some shields like llama-guard require the first message to be a user message
# since this might be a tool call, first role might not be user
if len(messages) > 0 and messages[0].role != "user":
messages[0] = OpenAIUserMessageParam(content=messages[0].content)
# Use the inference API's model resolution instead of hardcoded mappings
# This allows the shield to work with any registered model
model_id = shield.provider_resource_id
# Determine safety categories based on the model type
# For known Llama Guard models, use specific categories
if model_id in LLAMA_GUARD_MODEL_IDS:
# Use the mapped model for categories but the original model_id for inference
mapped_model = LLAMA_GUARD_MODEL_IDS[model_id]
safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
else:
# For unknown models, use default Llama Guard 3 8B categories
safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
impl = LlamaGuardShield(
model=model_id,
inference_api=self.inference_api,
excluded_categories=self.config.excluded_categories,
safety_categories=safety_categories,
)
return await impl.run(messages)
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
if model is None:
raise ValueError("Llama Guard moderation requires a model identifier.")
if isinstance(input, list):
messages = input.copy()
else:
messages = [input]
# convert to user messages format with role
messages = [OpenAIUserMessageParam(content=m) for m in messages]
# Determine safety categories based on the model type
# For known Llama Guard models, use specific categories
if model in LLAMA_GUARD_MODEL_IDS:
# Use the mapped model for categories but the original model_id for inference
mapped_model = LLAMA_GUARD_MODEL_IDS[model]
safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
else:
# For unknown models, use default Llama Guard 3 8B categories
safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
impl = LlamaGuardShield(
model=model,
inference_api=self.inference_api,
excluded_categories=self.config.excluded_categories,
safety_categories=safety_categories,
)
return await impl.run_moderation(messages)
class LlamaGuardShield:
def __init__(
self,
model: str,
inference_api: Inference,
excluded_categories: list[str] | None = None,
safety_categories: list[str] | None = None,
):
if excluded_categories is None:
excluded_categories = []
if safety_categories is None:
safety_categories = []
assert len(excluded_categories) == 0 or all(
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
self.model = model
self.inference_api = inference_api
self.excluded_categories = excluded_categories
self.safety_categories = safety_categories
def check_unsafe_response(self, response: str) -> str | None:
match = re.match(r"^unsafe\n(.*)$", response)
if match:
# extracts the unsafe code
extracted = match.group(1)
return extracted
return None
def get_safety_categories(self) -> list[str]:
excluded_categories = self.excluded_categories
if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()):
excluded_categories = []
final_categories = []
all_categories = self.safety_categories
for cat in all_categories:
cat_code = SAFETY_CATEGORIES_TO_CODE_MAP[cat]
if cat_code in excluded_categories:
continue
final_categories.append(f"{cat_code}: {cat}.")
return final_categories
def validate_messages(self, messages: list[OpenAIMessageParam]) -> list[OpenAIMessageParam]:
if len(messages) == 0:
raise ValueError("Messages must not be empty")
if messages[0].role != Role.user.value:
raise ValueError("Messages must start with user")
if len(messages) >= 2 and (messages[0].role == Role.user.value and messages[1].role == Role.user.value):
messages = messages[1:]
return messages
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
messages = self.validate_messages(messages)
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
shield_input_message = self.build_vision_shield_input(messages)
else:
shield_input_message = self.build_text_shield_input(messages)
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.model,
messages=[shield_input_message],
stream=False,
temperature=0.0, # default is 1, which is too high for safety
)
response = await self.inference_api.openai_chat_completion(params)
content = response.choices[0].message.content
content = content.strip()
return self.get_shield_response(content)
def build_text_shield_input(self, messages: list[OpenAIMessageParam]) -> OpenAIUserMessageParam:
return OpenAIUserMessageParam(content=self.build_prompt(messages))
def build_vision_shield_input(self, messages: list[OpenAIMessageParam]) -> OpenAIUserMessageParam:
conversation = []
most_recent_img = None
for m in messages[::-1]:
if isinstance(m.content, str) or isinstance(m.content, TextContentItem):
conversation.append(m)
elif isinstance(m.content, ImageContentItem):
if most_recent_img is None and m.role == Role.user.value:
most_recent_img = m.content
conversation.append(m)
elif isinstance(m.content, list):
content = []
for c in m.content:
if isinstance(c, str) or isinstance(c, TextContentItem):
content.append(c)
elif isinstance(c, ImageContentItem):
if most_recent_img is None and m.role == Role.user.value:
most_recent_img = c
content.append(c)
else:
raise ValueError(f"Unknown content type: {c}")
conversation.append(OpenAIUserMessageParam(content=content))
else:
raise ValueError(f"Unknown content type: {m.content}")
prompt = []
if most_recent_img is not None:
prompt.append(most_recent_img)
prompt.append(self.build_prompt(conversation[::-1]))
return OpenAIUserMessageParam(content=prompt)
def build_prompt(self, messages: list[OpenAIMessageParam]) -> str:
categories = self.get_safety_categories()
categories_str = "\n".join(categories)
conversations_str = "\n\n".join(
[f"{m.role.capitalize()}: {interleaved_content_as_str(m.content)}" for m in messages]
)
return PROMPT_TEMPLATE.substitute(
agent_type=messages[-1].role.capitalize(),
categories=categories_str,
conversations=conversations_str,
)
def get_shield_response(self, response: str) -> RunShieldResponse:
response = response.strip()
if response == SAFE_RESPONSE:
return RunShieldResponse(violation=None)
unsafe_code = self.check_unsafe_response(response)
if unsafe_code:
unsafe_code_list = unsafe_code.split(",")
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
return RunShieldResponse(violation=None)
return RunShieldResponse(
violation=SafetyViolation(
violation_level=ViolationLevel.ERROR,
user_message=CANNED_RESPONSE_TEXT,
metadata={"violation_type": unsafe_code},
),
)
raise ValueError(f"Unexpected response: {response}")
async def run_moderation(self, messages: list[OpenAIMessageParam]) -> ModerationObject:
if not messages:
return self.create_moderation_object(self.model)
# TODO: Add Image based support for OpenAI Moderations
shield_input_message = self.build_text_shield_input(messages)
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.model,
messages=[shield_input_message],
stream=False,
temperature=0.0, # default is 1, which is too high for safety
)
response = await self.inference_api.openai_chat_completion(params)
content = response.choices[0].message.content
content = content.strip()
return self.get_moderation_object(content)
def create_moderation_object(self, model: str, unsafe_code: str | None = None) -> ModerationObject:
"""Create a ModerationObject for either safe or unsafe content.
Args:
model: The model name
unsafe_code: Optional comma-separated list of safety codes. If None, creates safe object.
Returns:
ModerationObject with appropriate configuration
"""
# Set default values for safe case
categories = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), False)
category_scores = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), 1.0)
category_applied_input_types = {key: [] for key in SAFETY_CATEGORIES_TO_CODE_MAP.keys()}
flagged = False
user_message = None
metadata = {}
# Handle unsafe case
if unsafe_code:
unsafe_code_list = [code.strip() for code in unsafe_code.split(",")]
invalid_codes = [code for code in unsafe_code_list if code not in SAFETY_CODE_TO_CATEGORIES_MAP]
if invalid_codes:
logger.warning(f"Invalid safety codes returned: {invalid_codes}")
# just returning safe object, as we don't know what the invalid codes can map to
return ModerationObject(
id=f"modr-{uuid.uuid4()}",
model=model,
results=[
ModerationObjectResults(
flagged=flagged,
categories=categories,
category_applied_input_types=category_applied_input_types,
category_scores=category_scores,
user_message=user_message,
metadata=metadata,
)
],
)
llama_guard_category = [SAFETY_CODE_TO_CATEGORIES_MAP[code] for code in unsafe_code_list]
# Update categories for unsafe content
categories = {k: k in llama_guard_category for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()}
category_scores = {
k: 1.0 if k in llama_guard_category else 0.0 for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()
}
category_applied_input_types = {
k: ["text"] if k in llama_guard_category else [] for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()
}
flagged = True
user_message = CANNED_RESPONSE_TEXT
metadata = {"violation_type": unsafe_code_list}
return ModerationObject(
id=f"modr-{uuid.uuid4()}",
model=model,
results=[
ModerationObjectResults(
flagged=flagged,
categories=categories,
category_applied_input_types=category_applied_input_types,
category_scores=category_scores,
user_message=user_message,
metadata=metadata,
)
],
)
def is_content_safe(self, response: str, unsafe_code: str | None = None) -> bool:
"""Check if content is safe based on response and unsafe code."""
if response.strip().lower().startswith(SAFE_RESPONSE):
return True
if unsafe_code:
unsafe_code_list = unsafe_code.split(",")
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
return True
return False
def get_moderation_object(self, response: str) -> ModerationObject:
response = response.strip()
if self.is_content_safe(response):
return self.create_moderation_object(self.model)
unsafe_code = self.check_unsafe_response(response)
if not unsafe_code:
raise ValueError(f"Unexpected response: {response}")
if self.is_content_safe(response, unsafe_code):
return self.create_moderation_object(self.model)
else:
return self.create_moderation_object(self.model, unsafe_code)

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import PromptGuardConfig
async def get_provider_impl(config: PromptGuardConfig, deps: dict[str, Any]):
from .prompt_guard import PromptGuardSafetyImpl
impl = PromptGuardSafetyImpl(config, deps)
await impl.initialize()
return impl

View file

@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any
from pydantic import BaseModel, field_validator
class PromptGuardType(Enum):
injection = "injection"
jailbreak = "jailbreak"
class PromptGuardConfig(BaseModel):
guard_type: str = PromptGuardType.injection.value
@classmethod
@field_validator("guard_type")
def validate_guard_type(cls, v):
if v not in [t.value for t in PromptGuardType]:
raise ValueError(f"Unknown prompt guard type: {v}")
return v
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"guard_type": "injection",
}

View file

@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
SafetyViolation,
ShieldStore,
ViolationLevel,
)
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield
from llama_stack.core.utils.model_utils import model_local_dir
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from .config import PromptGuardConfig, PromptGuardType
log = get_logger(name=__name__, category="safety")
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
shield_store: ShieldStore
def __init__(self, config: PromptGuardConfig, _deps) -> None:
self.config = config
async def initialize(self) -> None:
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
self.shield = PromptGuardShield(model_dir, self.config)
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: Shield) -> None:
if shield.provider_resource_id != PROMPT_GUARD_MODEL:
raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ")
async def unregister_shield(self, identifier: str) -> None:
pass
async def run_shield(
self,
shield_id: str,
messages: list[OpenAIMessageParam],
params: dict[str, Any],
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Unknown shield {shield_id}")
return await self.shield.run(messages)
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
raise NotImplementedError("run_moderation is not implemented for Prompt Guard")
class PromptGuardShield:
def __init__(
self,
model_dir: str,
config: PromptGuardConfig,
threshold: float = 0.9,
temperature: float = 1.0,
):
assert model_dir is not None, "Must provide a model directory for prompt injection shield"
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
self.config = config
self.temperature = temperature
self.threshold = threshold
self.device = "cpu"
if torch.cuda.is_available():
self.device = "cuda"
# load model and tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
message = messages[-1]
text = interleaved_content_as_str(message.content)
# run model on messages and return response
inputs = self.tokenizer(text, return_tensors="pt")
inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs[0]
probabilities = torch.softmax(logits / self.temperature, dim=-1)
score_embedded = probabilities[0, 1].item()
score_malicious = probabilities[0, 2].item()
log.info(
f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
)
violation = None
if self.config.guard_type == PromptGuardType.injection.value and (
score_embedded + score_malicious > self.threshold
):
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR,
user_message="Sorry, I cannot do this.",
metadata={
"violation_type": f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
},
)
elif self.config.guard_type == PromptGuardType.jailbreak.value and score_malicious > self.threshold:
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR,
user_message="Sorry, I cannot do this.",
metadata={
"violation_type": f"prompt_injection:malicious={score_malicious}",
},
)
return RunShieldResponse(violation=violation)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.core.datatypes import Api
from .config import BasicScoringConfig
async def get_provider_impl(
config: BasicScoringConfig,
deps: dict[Api, Any],
):
from .scoring import BasicScoringImpl
impl = BasicScoringImpl(
config,
deps[Api.datasetio],
deps[Api.datasets],
)
await impl.initialize()
return impl

View file

@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel
class BasicScoringConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {}

View file

@ -0,0 +1,126 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.scoring import (
ScoreBatchResponse,
ScoreResponse,
Scoring,
ScoringResult,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.core.datatypes import Api
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
get_valid_schemas,
validate_dataset_schema,
)
from .config import BasicScoringConfig
from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.ifeval_scoring_fn import IfEvalScoringFn
from .scoring_fn.regex_parser_math_response_scoring_fn import (
RegexParserMathResponseScoringFn,
)
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
FIXED_FNS = [
EqualityScoringFn,
SubsetOfScoringFn,
RegexParserScoringFn,
RegexParserMathResponseScoringFn,
IfEvalScoringFn,
DocVQAScoringFn,
]
class BasicScoringImpl(
Scoring,
ScoringFunctionsProtocolPrivate,
):
def __init__(
self,
config: BasicScoringConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.scoring_fn_id_impls = {}
async def initialize(self) -> None:
for fn in FIXED_FNS:
impl = fn()
for fn_defs in impl.get_supported_scoring_fn_defs():
self.scoring_fn_id_impls[fn_defs.identifier] = impl
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> list[ScoringFn]:
scoring_fn_defs_list = [
fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
]
for f in scoring_fn_defs_list:
assert f.identifier.startswith("basic"), "All basic scoring fn must have identifier prefixed with 'basic'! "
return scoring_fn_defs_list
async def register_scoring_function(self, function_def: ScoringFn) -> None:
raise NotImplementedError("Register scoring function not implemented yet")
async def score_batch(
self,
dataset_id: str,
scoring_functions: dict[str, ScoringFnParams | None] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
limit=-1,
)
res = await self.score(
input_rows=all_rows.data,
scoring_functions=scoring_functions,
)
if save_results_dataset:
# TODO: persist and register dataset on to server for reading
# self.datasets_api.register_dataset()
raise NotImplementedError("Save results dataset not implemented yet")
return ScoreBatchResponse(
results=res.results,
)
async def score(
self,
input_rows: list[dict[str, Any]],
scoring_functions: dict[str, ScoringFnParams | None] = None,
) -> ScoreResponse:
res = {}
for scoring_fn_id in scoring_functions.keys():
if scoring_fn_id not in self.scoring_fn_id_impls:
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,
)
return ScoreResponse(
results=res,
)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,240 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import re
from typing import Any
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.docvqa import docvqa
CONTRACTIONS = {
"aint": "ain't",
"arent": "aren't",
"cant": "can't",
"couldve": "could've",
"couldnt": "couldn't",
"couldn'tve": "couldn't've",
"couldnt've": "couldn't've",
"didnt": "didn't",
"doesnt": "doesn't",
"dont": "don't",
"hadnt": "hadn't",
"hadnt've": "hadn't've",
"hadn'tve": "hadn't've",
"hasnt": "hasn't",
"havent": "haven't",
"hed": "he'd",
"hed've": "he'd've",
"he'dve": "he'd've",
"hes": "he's",
"howd": "how'd",
"howll": "how'll",
"hows": "how's",
"Id've": "I'd've",
"I'dve": "I'd've",
"Im": "I'm",
"Ive": "I've",
"isnt": "isn't",
"itd": "it'd",
"itd've": "it'd've",
"it'dve": "it'd've",
"itll": "it'll",
"let's": "let's",
"maam": "ma'am",
"mightnt": "mightn't",
"mightnt've": "mightn't've",
"mightn'tve": "mightn't've",
"mightve": "might've",
"mustnt": "mustn't",
"mustve": "must've",
"neednt": "needn't",
"notve": "not've",
"oclock": "o'clock",
"oughtnt": "oughtn't",
"ow's'at": "'ow's'at",
"'ows'at": "'ow's'at",
"'ow'sat": "'ow's'at",
"shant": "shan't",
"shed've": "she'd've",
"she'dve": "she'd've",
"she's": "she's",
"shouldve": "should've",
"shouldnt": "shouldn't",
"shouldnt've": "shouldn't've",
"shouldn'tve": "shouldn't've",
"somebody'd": "somebodyd",
"somebodyd've": "somebody'd've",
"somebody'dve": "somebody'd've",
"somebodyll": "somebody'll",
"somebodys": "somebody's",
"someoned": "someone'd",
"someoned've": "someone'd've",
"someone'dve": "someone'd've",
"someonell": "someone'll",
"someones": "someone's",
"somethingd": "something'd",
"somethingd've": "something'd've",
"something'dve": "something'd've",
"somethingll": "something'll",
"thats": "that's",
"thered": "there'd",
"thered've": "there'd've",
"there'dve": "there'd've",
"therere": "there're",
"theres": "there's",
"theyd": "they'd",
"theyd've": "they'd've",
"they'dve": "they'd've",
"theyll": "they'll",
"theyre": "they're",
"theyve": "they've",
"twas": "'twas",
"wasnt": "wasn't",
"wed've": "we'd've",
"we'dve": "we'd've",
"weve": "we've",
"werent": "weren't",
"whatll": "what'll",
"whatre": "what're",
"whats": "what's",
"whatve": "what've",
"whens": "when's",
"whered": "where'd",
"wheres": "where's",
"whereve": "where've",
"whod": "who'd",
"whod've": "who'd've",
"who'dve": "who'd've",
"wholl": "who'll",
"whos": "who's",
"whove": "who've",
"whyll": "why'll",
"whyre": "why're",
"whys": "why's",
"wont": "won't",
"wouldve": "would've",
"wouldnt": "wouldn't",
"wouldnt've": "wouldn't've",
"wouldn'tve": "wouldn't've",
"yall": "y'all",
"yall'll": "y'all'll",
"y'allll": "y'all'll",
"yall'd've": "y'all'd've",
"y'alld've": "y'all'd've",
"y'all'dve": "y'all'd've",
"youd": "you'd",
"youd've": "you'd've",
"you'dve": "you'd've",
"youll": "you'll",
"youre": "you're",
"youve": "you've",
"1st": "first",
"2nd": "second",
"3rd": "third",
}
NUMBERS = {
"none": "0",
"zero": "0",
"one": "1",
"two": "2",
"three": "3",
"four": "4",
"five": "5",
"six": "6",
"seven": "7",
"eight": "8",
"nine": "9",
"ten": "10",
}
ARTICLES = [
"a",
"an",
"the",
"to",
"in",
"from",
"by",
] # Contains a bit more than just articles, but we want to get rid of these elements influencing the accuracy
PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
COMMA_STRIP = re.compile(r"(\d)(\,)(\d)")
PUNCTUATION = [
";",
r"/",
"[",
"]",
'"',
"{",
"}",
"(",
")",
"=",
"+",
"\\",
"_",
"-",
">",
"<",
"@",
"`",
",",
"?",
"!",
]
def normalize_answer(s: str) -> str:
# process punctuation
for p in PUNCTUATION:
if (p + " " in s or " " + p in s) or (re.search(COMMA_STRIP, s) is not None):
s = s.replace(p, "")
else:
s = s.replace(p, " ")
s = PERIOD_STRIP.sub("", s, re.UNICODE)
# process digits and articles
temp_text = s.lower().split()
out_text = []
for word in temp_text:
word = NUMBERS.setdefault(word, word)
if word not in ARTICLES:
out_text.append(word)
# standardize contractions
for word_id, word in enumerate(out_text):
if word in CONTRACTIONS:
out_text[word_id] = CONTRACTIONS[word]
return " ".join(out_text)
class DocVQAScoringFn(RegisteredBaseScoringFn):
"""
docvqa basically matches the generated answer against several allowed
choices, but we need to normalize the answer to avoid penalizing
trivial differences
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
docvqa.identifier: docvqa,
}
async def score_row(
self,
input_row: dict[str, Any],
scoring_fn_identifier: str | None = "docvqa",
scoring_params: ScoringFnParams | None = None,
) -> ScoringResultRow:
expected_answers = json.loads(input_row["expected_answer"])
generated_answer = input_row["generated_answer"]
score = 1.0 if normalize_answer(generated_answer) in [normalize_answer(s) for s in expected_answers] else 0.0
return {
"score": score,
}

View file

@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.equality import equality
class EqualityScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
equality.identifier: equality,
}
async def score_row(
self,
input_row: dict[str, Any],
scoring_fn_identifier: str | None = "equality",
scoring_params: ScoringFnParams | None = None,
) -> ScoringResultRow:
assert "expected_answer" in input_row, "Expected answer not found in input row."
assert "generated_answer" in input_row, "Generated answer not found in input row."
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
score = 1.0 if expected_answer == generated_answer else 0.0
return {
"score": score,
}

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
docvqa = ScoringFn(
identifier="basic::docvqa",
description="DocVQA Visual Question & Answer scoring function",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="docvqa",
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
)

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
equality = ScoringFn(
identifier="basic::equality",
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
provider_id="basic",
provider_resource_id="equality",
return_type=NumberType(),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
)

View file

@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
ifeval = ScoringFn(
identifier="basic::ifeval",
description="Eval intruction follow capacity by checkping how many instructions can be followed in each example",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="ifeval",
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.weighted_average],
),
)

View file

@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
RegexParserScoringFnParams,
ScoringFn,
)
MATH_ANSWER_REGEXES = [r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"]
regex_parser_math_response = ScoringFn(
identifier="basic::regex_parser_math_response",
description="For math related benchmarks, extract answer from the generated response and expected_answer and see if they match",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="regex-parser-math-response",
params=RegexParserScoringFnParams(
parsing_regexes=MATH_ANSWER_REGEXES,
aggregation_functions=[AggregationFunctionType.accuracy],
),
)

View file

@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
RegexParserScoringFnParams,
ScoringFn,
)
MULTILINGUAL_ANSWER_REGEXES = [
r"The best answer is ",
r"Answer\s*:",
r"Answer\s*:", # Korean invisible character
r"উত্তর\s*:",
r"उत्तर\s*:",
r"উত্তরঃ",
r"উত্তর\s*:",
r"Antwort\s*:",
r"답변\s*:",
r"정답\s*:",
r"\s*:",
r"答案\s*",
r"答案\s*:",
r"\s*",
r"\s*:",
r"答复\s*",
r"答曰\s*",
r"الإجابة:",
r"الجواب:",
r"إجابة:",
r"الإجابة النهائية:",
r"الإجابة الصحيحة:",
r"الإجابة الصحيحة هي:",
r"الإجابة هي:",
r"Respuesta\s*:",
r"Risposta\s*:",
r"答え\s*:",
r"答え\s*",
r"回答\s*:",
r"回答\s*",
r"解答\s*:",
r"Jawaban\s*:",
r"Réponse\s*:",
r"Resposta\s*:",
r"Jibu\s*:",
r"Idahun\s*:",
r"Ìdáhùn\s*:",
r"Idáhùn\s*:",
r"Àmọ̀nà\s*:",
r"Àdáhùn\s*:",
r"Ànúgọ\s*:",
r"Àṣàyàn\s*:",
]
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[]|[]|[]|[])"
regex_parser_multiple_choice_answer = ScoringFn(
identifier="basic::regex_parser_multiple_choice_answer",
description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="regex-parser-multiple-choice-answer",
params=RegexParserScoringFnParams(
parsing_regexes=[MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x) for x in MULTILINGUAL_ANSWER_REGEXES],
aggregation_functions=[AggregationFunctionType.accuracy],
),
)

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
subset_of = ScoringFn(
identifier="basic::subset_of",
description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="subset-of",
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
)

View file

@ -0,0 +1,80 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.ifeval import (
ifeval,
)
class IfEvalScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn Instruction-Following Eval (IFEval) benchmark
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
ifeval.identifier: ifeval,
}
async def score_row(
self,
input_row: dict[str, Any],
scoring_fn_identifier: str | None = None,
scoring_params: ScoringFnParams | None = None,
) -> ScoringResultRow:
from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
if scoring_params is not None:
fn_def.params = scoring_params
instruction_list = input_row["instruction_id_list"]
generated_answer = input_row["generated_answer"].strip()
is_following_list = []
results = dict(
{k + "_correct": 0.0 for k in INSTRUCTION_LIST},
**{k + "_total": 0.0 for k in INSTRUCTION_LIST},
)
for index, instruction_id in enumerate(instruction_list):
instruction_cls = INSTRUCTION_DICT[instruction_id]
instruction = instruction_cls(instruction_id)
results[instruction_id + "_total"] += 1.0
results[instruction_id.split(":")[0] + "_total"] += 1.0
clean_input_row = {k: v for k, v in input_row["kwargs"][index].items() if v is not None}
print(clean_input_row)
instruction.build_description(**clean_input_row)
args = instruction.get_instruction_args()
if args and "prompt" in args:
instruction.build_description(prompt=input_row["prompt"])
if generated_answer and instruction.check_following(generated_answer):
is_following_list.append(True)
results[instruction_id + "_correct"] += 1.0
results[instruction_id.split(":")[0] + "_correct"] += 1.0
else:
is_following_list.append(False)
if len(is_following_list) == 0:
return {
"score": 0.0,
"weight": 0.0,
}
return {
"score": float(sum(is_following_list)) / float(len(is_following_list)),
"weight": float(len(is_following_list)),
}

View file

@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from ..utils.math_utils import first_answer, normalize_final_answer, try_evaluate_frac, try_evaluate_latex
from .fn_defs.regex_parser_math_response import (
regex_parser_math_response,
)
class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn for math benchamrks that parses answer from generated response according to context and check match with expected_answer.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
regex_parser_math_response.identifier: regex_parser_math_response,
}
async def score_row(
self,
input_row: dict[str, Any],
scoring_fn_identifier: str | None = None,
scoring_params: ScoringFnParams | None = None,
) -> ScoringResultRow:
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
if scoring_params is not None:
fn_def.params = scoring_params
assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
f"RegexParserScoringFnParams not found for {fn_def}."
)
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
parsing_regexes = fn_def.params.parsing_regexes
assert len(parsing_regexes) == 1, (
"Only one parsing regex is supported for regex_parser_math_response scoring function."
)
parsing_regexes = fn_def.params.parsing_regexes[0]
normalized_generated_answer = normalize_final_answer(
first_answer(generated_answer),
parsing_regexes,
match_first=True,
)
normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer))
normalized_expected_answer = normalize_final_answer(expected_answer, r".*")
normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer))
score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0
return {
"score": score,
}

View file

@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from typing import Any
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.regex_parser_multiple_choice_answer import (
regex_parser_multiple_choice_answer,
)
class RegexParserScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
regex_parser_multiple_choice_answer.identifier: regex_parser_multiple_choice_answer,
}
async def score_row(
self,
input_row: dict[str, Any],
scoring_fn_identifier: str | None = None,
scoring_params: ScoringFnParams | None = None,
) -> ScoringResultRow:
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
if scoring_params is not None:
fn_def.params = scoring_params
assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
f"RegexParserScoringFnParams not found for {fn_def}."
)
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
# parse answer according to regex
parsed_answer = None
for regex in fn_def.params.parsing_regexes:
match = re.search(regex, generated_answer)
if match:
parsed_answer = match.group(1)
break
score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0
return {
"score": score,
}

View file

@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.subset_of import subset_of
class SubsetOfScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
subset_of.identifier: subset_of,
}
async def score_row(
self,
input_row: dict[str, Any],
scoring_fn_identifier: str | None = "subset_of",
scoring_params: ScoringFnParams | None = None,
) -> ScoringResultRow:
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
score = 1.0 if expected_answer in generated_answer else 0.0
return {
"score": score,
}

Some files were not shown because too many files have changed in this diff Show more