mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-04 10:10:36 +00:00
chore(package): migrate to src/ layout (#3920)
Migrates package structure to src/ layout following Python packaging best practices. All code moved from `llama_stack/` to `src/llama_stack/`. Public API unchanged - imports remain `import llama_stack.*`. Updated build configs, pre-commit hooks, scripts, and GitHub workflows accordingly. All hooks pass, package builds cleanly. **Developer note**: Reinstall after pulling: `pip install -e .`
This commit is contained in:
parent
98a5047f9d
commit
471b1b248b
791 changed files with 2983 additions and 456 deletions
5
src/llama_stack/providers/inline/__init__.py
Normal file
5
src/llama_stack/providers/inline/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
5
src/llama_stack/providers/inline/agents/__init__.py
Normal file
5
src/llama_stack/providers/inline/agents/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import AccessRule, Api
|
||||
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceAgentsImplConfig,
|
||||
deps: dict[Api, Any],
|
||||
policy: list[AccessRule],
|
||||
telemetry_enabled: bool = False,
|
||||
):
|
||||
from .agents import MetaReferenceAgentsImpl
|
||||
|
||||
impl = MetaReferenceAgentsImpl(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.vector_io],
|
||||
deps[Api.safety],
|
||||
deps[Api.tool_runtime],
|
||||
deps[Api.tool_groups],
|
||||
deps[Api.conversations],
|
||||
policy,
|
||||
telemetry_enabled,
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
File diff suppressed because it is too large
Load diff
383
src/llama_stack/providers/inline/agents/meta_reference/agents.py
Normal file
383
src/llama_stack/providers/inline/agents/meta_reference/agents.py
Normal file
|
|
@ -0,0 +1,383 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
Agent,
|
||||
AgentConfig,
|
||||
AgentCreateResponse,
|
||||
Agents,
|
||||
AgentSessionCreateResponse,
|
||||
AgentStepResponse,
|
||||
AgentToolGroup,
|
||||
AgentTurnCreateRequest,
|
||||
AgentTurnResumeRequest,
|
||||
Document,
|
||||
ListOpenAIResponseInputItem,
|
||||
ListOpenAIResponseObject,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseObject,
|
||||
Order,
|
||||
Session,
|
||||
Turn,
|
||||
)
|
||||
from llama_stack.apis.agents.agents import ResponseGuardrail
|
||||
from llama_stack.apis.agents.openai_responses import OpenAIResponseText
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
ToolConfig,
|
||||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
||||
from llama_stack.providers.utils.pagination import paginate_records
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
||||
from .agent_instance import ChatAgent
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
from .persistence import AgentInfo
|
||||
from .responses.openai_responses import OpenAIResponsesImpl
|
||||
|
||||
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
||||
class MetaReferenceAgentsImpl(Agents):
|
||||
def __init__(
|
||||
self,
|
||||
config: MetaReferenceAgentsImplConfig,
|
||||
inference_api: Inference,
|
||||
vector_io_api: VectorIO,
|
||||
safety_api: Safety,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
tool_groups_api: ToolGroups,
|
||||
conversations_api: Conversations,
|
||||
policy: list[AccessRule],
|
||||
telemetry_enabled: bool = False,
|
||||
):
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.vector_io_api = vector_io_api
|
||||
self.safety_api = safety_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.conversations_api = conversations_api
|
||||
self.telemetry_enabled = telemetry_enabled
|
||||
|
||||
self.in_memory_store = InmemoryKVStoreImpl()
|
||||
self.openai_responses_impl: OpenAIResponsesImpl | None = None
|
||||
self.policy = policy
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.persistence_store = await kvstore_impl(self.config.persistence.agent_state)
|
||||
self.responses_store = ResponsesStore(self.config.persistence.responses, self.policy)
|
||||
await self.responses_store.initialize()
|
||||
self.openai_responses_impl = OpenAIResponsesImpl(
|
||||
inference_api=self.inference_api,
|
||||
tool_groups_api=self.tool_groups_api,
|
||||
tool_runtime_api=self.tool_runtime_api,
|
||||
responses_store=self.responses_store,
|
||||
vector_io_api=self.vector_io_api,
|
||||
safety_api=self.safety_api,
|
||||
conversations_api=self.conversations_api,
|
||||
)
|
||||
|
||||
async def create_agent(
|
||||
self,
|
||||
agent_config: AgentConfig,
|
||||
) -> AgentCreateResponse:
|
||||
agent_id = str(uuid.uuid4())
|
||||
created_at = datetime.now(UTC)
|
||||
|
||||
agent_info = AgentInfo(
|
||||
**agent_config.model_dump(),
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
# Store the agent info
|
||||
await self.persistence_store.set(
|
||||
key=f"agent:{agent_id}",
|
||||
value=agent_info.model_dump_json(),
|
||||
)
|
||||
|
||||
return AgentCreateResponse(
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
|
||||
agent_info_json = await self.persistence_store.get(
|
||||
key=f"agent:{agent_id}",
|
||||
)
|
||||
if not agent_info_json:
|
||||
raise ValueError(f"Could not find agent info for {agent_id}")
|
||||
|
||||
try:
|
||||
agent_info = AgentInfo.model_validate_json(agent_info_json)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not validate agent info for {agent_id}") from e
|
||||
|
||||
return ChatAgent(
|
||||
agent_id=agent_id,
|
||||
agent_config=agent_info,
|
||||
inference_api=self.inference_api,
|
||||
safety_api=self.safety_api,
|
||||
vector_io_api=self.vector_io_api,
|
||||
tool_runtime_api=self.tool_runtime_api,
|
||||
tool_groups_api=self.tool_groups_api,
|
||||
persistence_store=(
|
||||
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
|
||||
),
|
||||
created_at=agent_info.created_at,
|
||||
policy=self.policy,
|
||||
telemetry_enabled=self.telemetry_enabled,
|
||||
)
|
||||
|
||||
async def create_agent_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_name: str,
|
||||
) -> AgentSessionCreateResponse:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
session_id = await agent.create_session(session_name)
|
||||
return AgentSessionCreateResponse(
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def create_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
messages: list[UserMessage | ToolResponseMessage],
|
||||
toolgroups: list[AgentToolGroup] | None = None,
|
||||
documents: list[Document] | None = None,
|
||||
stream: bool | None = False,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnCreateRequest(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=messages,
|
||||
stream=True,
|
||||
toolgroups=toolgroups,
|
||||
documents=documents,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
if stream:
|
||||
return self._create_agent_turn_streaming(request)
|
||||
else:
|
||||
raise NotImplementedError("Non-streaming agent turns not yet implemented")
|
||||
|
||||
async def _create_agent_turn_streaming(
|
||||
self,
|
||||
request: AgentTurnCreateRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self._get_agent_impl(request.agent_id)
|
||||
async for event in agent.create_and_execute_turn(request):
|
||||
yield event
|
||||
|
||||
async def resume_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: list[ToolResponse],
|
||||
stream: bool | None = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnResumeRequest(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
turn_id=turn_id,
|
||||
tool_responses=tool_responses,
|
||||
stream=stream,
|
||||
)
|
||||
if stream:
|
||||
return self._continue_agent_turn_streaming(request)
|
||||
else:
|
||||
raise NotImplementedError("Non-streaming agent turns not yet implemented")
|
||||
|
||||
async def _continue_agent_turn_streaming(
|
||||
self,
|
||||
request: AgentTurnResumeRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self._get_agent_impl(request.agent_id)
|
||||
async for event in agent.resume_turn(request):
|
||||
yield event
|
||||
|
||||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
turn = await agent.storage.get_session_turn(session_id, turn_id)
|
||||
return turn
|
||||
|
||||
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
|
||||
turn = await self.get_agents_turn(agent_id, session_id, turn_id)
|
||||
for step in turn.steps:
|
||||
if step.step_id == step_id:
|
||||
return AgentStepResponse(step=step)
|
||||
raise ValueError(f"Provided step_id {step_id} could not be found")
|
||||
|
||||
async def get_agents_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_ids: list[str] | None = None,
|
||||
) -> Session:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
session_info = await agent.storage.get_session_info(session_id)
|
||||
turns = await agent.storage.get_session_turns(session_id)
|
||||
if turn_ids:
|
||||
turns = [turn for turn in turns if turn.turn_id in turn_ids]
|
||||
return Session(
|
||||
session_name=session_info.session_name,
|
||||
session_id=session_id,
|
||||
turns=turns,
|
||||
started_at=session_info.started_at,
|
||||
)
|
||||
|
||||
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
# Delete turns first, then the session
|
||||
await agent.storage.delete_session_turns(session_id)
|
||||
await agent.storage.delete_session(session_id)
|
||||
|
||||
async def delete_agent(self, agent_id: str) -> None:
|
||||
# First get all sessions for this agent
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
sessions = await agent.storage.list_sessions()
|
||||
|
||||
# Delete all sessions
|
||||
for session in sessions:
|
||||
await self.delete_agents_session(agent_id, session.session_id)
|
||||
|
||||
# Finally delete the agent itself
|
||||
await self.persistence_store.delete(f"agent:{agent_id}")
|
||||
|
||||
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
|
||||
agent_keys = await self.persistence_store.keys_in_range("agent:", "agent:\xff")
|
||||
agent_list: list[Agent] = []
|
||||
for agent_key in agent_keys:
|
||||
agent_id = agent_key.split(":")[1]
|
||||
|
||||
# Get the agent info using the key
|
||||
agent_info_json = await self.persistence_store.get(agent_key)
|
||||
if not agent_info_json:
|
||||
logger.error(f"Could not find agent info for key {agent_key}")
|
||||
continue
|
||||
|
||||
try:
|
||||
agent_info = AgentInfo.model_validate_json(agent_info_json)
|
||||
agent_list.append(
|
||||
Agent(
|
||||
agent_id=agent_id,
|
||||
agent_config=agent_info,
|
||||
created_at=agent_info.created_at,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing agent info for {agent_id}: {e}")
|
||||
continue
|
||||
|
||||
# Convert Agent objects to dictionaries
|
||||
agent_dicts = [agent.model_dump() for agent in agent_list]
|
||||
return paginate_records(agent_dicts, start_index, limit)
|
||||
|
||||
async def get_agent(self, agent_id: str) -> Agent:
|
||||
chat_agent = await self._get_agent_impl(agent_id)
|
||||
agent = Agent(
|
||||
agent_id=agent_id,
|
||||
agent_config=chat_agent.agent_config,
|
||||
created_at=chat_agent.created_at,
|
||||
)
|
||||
return agent
|
||||
|
||||
async def list_agent_sessions(
|
||||
self, agent_id: str, start_index: int | None = None, limit: int | None = None
|
||||
) -> PaginatedResponse:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
sessions = await agent.storage.list_sessions()
|
||||
# Convert Session objects to dictionaries
|
||||
session_dicts = [session.model_dump() for session in sessions]
|
||||
return paginate_records(session_dicts, start_index, limit)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
# OpenAI responses
|
||||
async def get_openai_response(
|
||||
self,
|
||||
response_id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
return await self.openai_responses_impl.get_openai_response(response_id)
|
||||
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
conversation: str | None = None,
|
||||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
guardrails: list[ResponseGuardrail] | None = None,
|
||||
) -> OpenAIResponseObject:
|
||||
return await self.openai_responses_impl.create_openai_response(
|
||||
input,
|
||||
model,
|
||||
instructions,
|
||||
previous_response_id,
|
||||
conversation,
|
||||
store,
|
||||
stream,
|
||||
temperature,
|
||||
text,
|
||||
tools,
|
||||
include,
|
||||
max_infer_iters,
|
||||
guardrails,
|
||||
)
|
||||
|
||||
async def list_openai_responses(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 50,
|
||||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseObject:
|
||||
return await self.openai_responses_impl.list_openai_responses(after, limit, model, order)
|
||||
|
||||
async def list_openai_response_input_items(
|
||||
self,
|
||||
response_id: str,
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
include: list[str] | None = None,
|
||||
limit: int | None = 20,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseInputItem:
|
||||
return await self.openai_responses_impl.list_openai_response_input_items(
|
||||
response_id, after, before, include, limit, order
|
||||
)
|
||||
|
||||
async def delete_openai_response(self, response_id: str) -> None:
|
||||
return await self.openai_responses_impl.delete_openai_response(response_id)
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
|
||||
|
||||
|
||||
class AgentPersistenceConfig(BaseModel):
|
||||
"""Nested persistence configuration for agents."""
|
||||
|
||||
agent_state: KVStoreReference
|
||||
responses: ResponsesStoreReference
|
||||
|
||||
|
||||
class MetaReferenceAgentsImplConfig(BaseModel):
|
||||
persistence: AgentPersistenceConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
return {
|
||||
"persistence": {
|
||||
"agent_state": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="agents",
|
||||
).model_dump(exclude_none=True),
|
||||
"responses": ResponsesStoreReference(
|
||||
backend="sql_default",
|
||||
table_name="responses",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,228 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
|
||||
from llama_stack.apis.common.errors import SessionNotFoundError
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.core.access_control.datatypes import AccessRule
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.core.request_headers import get_authenticated_user
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
log = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
||||
class AgentSessionInfo(Session):
|
||||
# TODO: is this used anywhere?
|
||||
vector_db_id: str | None = None
|
||||
started_at: datetime
|
||||
owner: User | None = None
|
||||
identifier: str | None = None
|
||||
type: str = "session"
|
||||
|
||||
|
||||
class AgentInfo(AgentConfig):
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class AgentPersistence:
|
||||
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
|
||||
self.agent_id = agent_id
|
||||
self.kvstore = kvstore
|
||||
self.policy = policy
|
||||
|
||||
async def create_session(self, name: str) -> str:
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# Get current user's auth attributes for new sessions
|
||||
user = get_authenticated_user()
|
||||
|
||||
session_info = AgentSessionInfo(
|
||||
session_id=session_id,
|
||||
session_name=name,
|
||||
started_at=datetime.now(UTC),
|
||||
owner=user,
|
||||
turns=[],
|
||||
identifier=name, # should this be qualified in any way?
|
||||
)
|
||||
if not is_action_allowed(self.policy, "create", session_info, user):
|
||||
raise AccessDeniedError("create", session_info, user)
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
value=session_info.model_dump_json(),
|
||||
)
|
||||
return session_id
|
||||
|
||||
async def get_session_info(self, session_id: str) -> AgentSessionInfo:
|
||||
value = await self.kvstore.get(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
)
|
||||
if not value:
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
session_info = AgentSessionInfo(**json.loads(value))
|
||||
|
||||
# Check access to session
|
||||
if not self._check_session_access(session_info):
|
||||
return None
|
||||
|
||||
return session_info
|
||||
|
||||
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
|
||||
"""Check if current user has access to the session."""
|
||||
# Handle backward compatibility for old sessions without access control
|
||||
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
|
||||
return True
|
||||
|
||||
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
|
||||
|
||||
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
|
||||
"""Get session info if the user has access to it. For internal use by sub-session methods."""
|
||||
session_info = await self.get_session_info(session_id)
|
||||
if not session_info:
|
||||
return None
|
||||
|
||||
return session_info
|
||||
|
||||
async def add_vector_db_to_session(self, session_id: str, vector_db_id: str):
|
||||
session_info = await self.get_session_if_accessible(session_id)
|
||||
if session_info is None:
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
session_info.vector_db_id = vector_db_id
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
value=session_info.model_dump_json(),
|
||||
)
|
||||
|
||||
async def add_turn_to_session(self, session_id: str, turn: Turn):
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
|
||||
value=turn.model_dump_json(),
|
||||
)
|
||||
|
||||
async def get_session_turns(self, session_id: str) -> list[Turn]:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
values = await self.kvstore.values_in_range(
|
||||
start_key=f"session:{self.agent_id}:{session_id}:",
|
||||
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
|
||||
)
|
||||
turns = []
|
||||
for value in values:
|
||||
try:
|
||||
turn = Turn(**json.loads(value))
|
||||
turns.append(turn)
|
||||
except Exception as e:
|
||||
log.error(f"Error parsing turn: {e}")
|
||||
continue
|
||||
|
||||
# The kvstore does not guarantee order, so we sort by started_at
|
||||
# to ensure consistent ordering of turns.
|
||||
turns.sort(key=lambda t: t.started_at)
|
||||
|
||||
return turns
|
||||
|
||||
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
value = await self.kvstore.get(
|
||||
key=f"session:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
if not value:
|
||||
return None
|
||||
return Turn(**json.loads(value))
|
||||
|
||||
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||
value=step.model_dump_json(),
|
||||
)
|
||||
|
||||
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> ToolExecutionStep | None:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
return None
|
||||
|
||||
value = await self.kvstore.get(
|
||||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
return ToolExecutionStep(**json.loads(value)) if value else None
|
||||
|
||||
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||
value=str(num_infer_iters),
|
||||
)
|
||||
|
||||
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> int | None:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
return None
|
||||
|
||||
value = await self.kvstore.get(
|
||||
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
return int(value) if value else None
|
||||
|
||||
async def list_sessions(self) -> list[Session]:
|
||||
values = await self.kvstore.values_in_range(
|
||||
start_key=f"session:{self.agent_id}:",
|
||||
end_key=f"session:{self.agent_id}:\xff\xff\xff\xff",
|
||||
)
|
||||
sessions = []
|
||||
for value in values:
|
||||
try:
|
||||
data = json.loads(value)
|
||||
if "turn_id" in data:
|
||||
continue
|
||||
|
||||
session_info = Session(**data)
|
||||
sessions.append(session_info)
|
||||
except Exception as e:
|
||||
log.error(f"Error parsing session info: {e}")
|
||||
continue
|
||||
return sessions
|
||||
|
||||
async def delete_session_turns(self, session_id: str) -> None:
|
||||
"""Delete all turns and their associated data for a session.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the session whose turns should be deleted.
|
||||
"""
|
||||
turns = await self.get_session_turns(session_id)
|
||||
for turn in turns:
|
||||
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}")
|
||||
|
||||
async def delete_session(self, session_id: str) -> None:
|
||||
"""Delete a session and all its associated turns.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the session to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If the session does not exist.
|
||||
"""
|
||||
session_info = await self.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}")
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -0,0 +1,424 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from llama_stack.apis.agents import Order
|
||||
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
ListOpenAIResponseInputItem,
|
||||
ListOpenAIResponseObject,
|
||||
OpenAIDeleteResponseObject,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
)
|
||||
from llama_stack.apis.common.errors import (
|
||||
InvalidConversationIdError,
|
||||
)
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.conversations.conversations import ConversationItem
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
)
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.responses.responses_store import (
|
||||
ResponsesStore,
|
||||
_OpenAIResponseObjectWithInputAndMessages,
|
||||
)
|
||||
|
||||
from .streaming import StreamingResponseOrchestrator
|
||||
from .tool_executor import ToolExecutor
|
||||
from .types import ChatCompletionContext, ToolContext
|
||||
from .utils import (
|
||||
convert_response_input_to_chat_messages,
|
||||
convert_response_text_to_chat_response_format,
|
||||
extract_guardrail_ids,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="openai_responses")
|
||||
|
||||
|
||||
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
|
||||
input_items: ListOpenAIResponseInputItem
|
||||
response: OpenAIResponseObject
|
||||
|
||||
|
||||
class OpenAIResponsesImpl:
|
||||
def __init__(
|
||||
self,
|
||||
inference_api: Inference,
|
||||
tool_groups_api: ToolGroups,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
responses_store: ResponsesStore,
|
||||
vector_io_api: VectorIO, # VectorIO
|
||||
safety_api: Safety,
|
||||
conversations_api: Conversations,
|
||||
):
|
||||
self.inference_api = inference_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.responses_store = responses_store
|
||||
self.vector_io_api = vector_io_api
|
||||
self.safety_api = safety_api
|
||||
self.conversations_api = conversations_api
|
||||
self.tool_executor = ToolExecutor(
|
||||
tool_groups_api=tool_groups_api,
|
||||
tool_runtime_api=tool_runtime_api,
|
||||
vector_io_api=vector_io_api,
|
||||
)
|
||||
|
||||
async def _prepend_previous_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
previous_response: _OpenAIResponseObjectWithInputAndMessages,
|
||||
):
|
||||
new_input_items = previous_response.input.copy()
|
||||
new_input_items.extend(previous_response.output)
|
||||
|
||||
if isinstance(input, str):
|
||||
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
|
||||
else:
|
||||
new_input_items.extend(input)
|
||||
|
||||
return new_input_items
|
||||
|
||||
async def _process_input_with_previous_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
tools: list[OpenAIResponseInputTool] | None,
|
||||
previous_response_id: str | None,
|
||||
conversation: str | None,
|
||||
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
|
||||
"""Process input with optional previous response context.
|
||||
|
||||
Returns:
|
||||
tuple: (all_input for storage, messages for chat completion, tool context)
|
||||
"""
|
||||
tool_context = ToolContext(tools)
|
||||
if previous_response_id:
|
||||
previous_response: _OpenAIResponseObjectWithInputAndMessages = (
|
||||
await self.responses_store.get_response_object(previous_response_id)
|
||||
)
|
||||
all_input = await self._prepend_previous_response(input, previous_response)
|
||||
|
||||
if previous_response.messages:
|
||||
# Use stored messages directly and convert only new input
|
||||
message_adapter = TypeAdapter(list[OpenAIMessageParam])
|
||||
messages = message_adapter.validate_python(previous_response.messages)
|
||||
new_messages = await convert_response_input_to_chat_messages(input, previous_messages=messages)
|
||||
messages.extend(new_messages)
|
||||
else:
|
||||
# Backward compatibility: reconstruct from inputs
|
||||
messages = await convert_response_input_to_chat_messages(all_input)
|
||||
|
||||
tool_context.recover_tools_from_previous_response(previous_response)
|
||||
elif conversation is not None:
|
||||
conversation_items = await self.conversations_api.list_items(conversation, order="asc")
|
||||
|
||||
# Use stored messages as source of truth (like previous_response.messages)
|
||||
stored_messages = await self.responses_store.get_conversation_messages(conversation)
|
||||
|
||||
all_input = input
|
||||
if not conversation_items.data:
|
||||
# First turn - just convert the new input
|
||||
messages = await convert_response_input_to_chat_messages(input)
|
||||
else:
|
||||
if not stored_messages:
|
||||
all_input = conversation_items.data
|
||||
if isinstance(input, str):
|
||||
all_input.append(
|
||||
OpenAIResponseMessage(
|
||||
role="user", content=[OpenAIResponseInputMessageContentText(text=input)]
|
||||
)
|
||||
)
|
||||
else:
|
||||
all_input.extend(input)
|
||||
else:
|
||||
all_input = input
|
||||
|
||||
messages = stored_messages or []
|
||||
new_messages = await convert_response_input_to_chat_messages(all_input, previous_messages=messages)
|
||||
messages.extend(new_messages)
|
||||
else:
|
||||
all_input = input
|
||||
messages = await convert_response_input_to_chat_messages(all_input)
|
||||
|
||||
return all_input, messages, tool_context
|
||||
|
||||
async def get_openai_response(
|
||||
self,
|
||||
response_id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
response_with_input = await self.responses_store.get_response_object(response_id)
|
||||
return response_with_input.to_response_object()
|
||||
|
||||
async def list_openai_responses(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 50,
|
||||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseObject:
|
||||
return await self.responses_store.list_responses(after, limit, model, order)
|
||||
|
||||
async def list_openai_response_input_items(
|
||||
self,
|
||||
response_id: str,
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
include: list[str] | None = None,
|
||||
limit: int | None = 20,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseInputItem:
|
||||
"""List input items for a given OpenAI response.
|
||||
|
||||
:param response_id: The ID of the response to retrieve input items for.
|
||||
:param after: An item ID to list items after, used for pagination.
|
||||
:param before: An item ID to list items before, used for pagination.
|
||||
:param include: Additional fields to include in the response.
|
||||
:param limit: A limit on the number of objects to be returned.
|
||||
:param order: The order to return the input items in.
|
||||
:returns: An ListOpenAIResponseInputItem.
|
||||
"""
|
||||
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
|
||||
|
||||
async def _store_response(
|
||||
self,
|
||||
response: OpenAIResponseObject,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
messages: list[OpenAIMessageParam],
|
||||
) -> None:
|
||||
new_input_id = f"msg_{uuid.uuid4()}"
|
||||
if isinstance(input, str):
|
||||
# synthesize a message from the input string
|
||||
input_content = OpenAIResponseInputMessageContentText(text=input)
|
||||
input_content_item = OpenAIResponseMessage(
|
||||
role="user",
|
||||
content=[input_content],
|
||||
id=new_input_id,
|
||||
)
|
||||
input_items_data = [input_content_item]
|
||||
else:
|
||||
# we already have a list of messages
|
||||
input_items_data = []
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseMessage):
|
||||
# These may or may not already have an id, so dump to dict, check for id, and add if missing
|
||||
input_item_dict = input_item.model_dump()
|
||||
if "id" not in input_item_dict:
|
||||
input_item_dict["id"] = new_input_id
|
||||
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
|
||||
else:
|
||||
input_items_data.append(input_item)
|
||||
|
||||
await self.responses_store.store_response_object(
|
||||
response_object=response,
|
||||
input=input_items_data,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
conversation: str | None = None,
|
||||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
guardrails: list[ResponseGuardrailSpec] | None = None,
|
||||
):
|
||||
stream = bool(stream)
|
||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
|
||||
|
||||
guardrail_ids = extract_guardrail_ids(guardrails) if guardrails else []
|
||||
|
||||
if conversation is not None:
|
||||
if previous_response_id is not None:
|
||||
raise ValueError(
|
||||
"Mutually exclusive parameters: 'previous_response_id' and 'conversation'. Ensure you are only providing one of these parameters."
|
||||
)
|
||||
|
||||
if not conversation.startswith("conv_"):
|
||||
raise InvalidConversationIdError(conversation)
|
||||
|
||||
stream_gen = self._create_streaming_response(
|
||||
input=input,
|
||||
conversation=conversation,
|
||||
model=model,
|
||||
instructions=instructions,
|
||||
previous_response_id=previous_response_id,
|
||||
store=store,
|
||||
temperature=temperature,
|
||||
text=text,
|
||||
tools=tools,
|
||||
max_infer_iters=max_infer_iters,
|
||||
guardrail_ids=guardrail_ids,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return stream_gen
|
||||
else:
|
||||
final_response = None
|
||||
final_event_type = None
|
||||
failed_response = None
|
||||
|
||||
async for stream_chunk in stream_gen:
|
||||
if stream_chunk.type in {"response.completed", "response.incomplete"}:
|
||||
if final_response is not None:
|
||||
raise ValueError(
|
||||
"The response stream produced multiple terminal responses! "
|
||||
f"Earlier response from {final_event_type}"
|
||||
)
|
||||
final_response = stream_chunk.response
|
||||
final_event_type = stream_chunk.type
|
||||
elif stream_chunk.type == "response.failed":
|
||||
failed_response = stream_chunk.response
|
||||
|
||||
if failed_response is not None:
|
||||
error_message = (
|
||||
failed_response.error.message
|
||||
if failed_response and failed_response.error
|
||||
else "Response stream failed without error details"
|
||||
)
|
||||
raise RuntimeError(f"OpenAI response failed: {error_message}")
|
||||
|
||||
if final_response is None:
|
||||
raise ValueError("The response stream never reached a terminal state")
|
||||
return final_response
|
||||
|
||||
async def _create_streaming_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
conversation: str | None = None,
|
||||
store: bool | None = True,
|
||||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
guardrail_ids: list[str] | None = None,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Input preprocessing
|
||||
all_input, messages, tool_context = await self._process_input_with_previous_response(
|
||||
input, tools, previous_response_id, conversation
|
||||
)
|
||||
|
||||
if instructions:
|
||||
messages.insert(0, OpenAISystemMessageParam(content=instructions))
|
||||
|
||||
# Structured outputs
|
||||
response_format = await convert_response_text_to_chat_response_format(text)
|
||||
|
||||
ctx = ChatCompletionContext(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_tools=tools,
|
||||
temperature=temperature,
|
||||
response_format=response_format,
|
||||
tool_context=tool_context,
|
||||
inputs=all_input,
|
||||
)
|
||||
|
||||
# Create orchestrator and delegate streaming logic
|
||||
response_id = f"resp_{uuid.uuid4()}"
|
||||
created_at = int(time.time())
|
||||
|
||||
orchestrator = StreamingResponseOrchestrator(
|
||||
inference_api=self.inference_api,
|
||||
ctx=ctx,
|
||||
response_id=response_id,
|
||||
created_at=created_at,
|
||||
text=text,
|
||||
max_infer_iters=max_infer_iters,
|
||||
tool_executor=self.tool_executor,
|
||||
safety_api=self.safety_api,
|
||||
guardrail_ids=guardrail_ids,
|
||||
instructions=instructions,
|
||||
)
|
||||
|
||||
# Stream the response
|
||||
final_response = None
|
||||
failed_response = None
|
||||
|
||||
output_items = []
|
||||
async for stream_chunk in orchestrator.create_response():
|
||||
if stream_chunk.type in {"response.completed", "response.incomplete"}:
|
||||
final_response = stream_chunk.response
|
||||
elif stream_chunk.type == "response.failed":
|
||||
failed_response = stream_chunk.response
|
||||
|
||||
if stream_chunk.type == "response.output_item.done":
|
||||
item = stream_chunk.item
|
||||
output_items.append(item)
|
||||
|
||||
# Store and sync before yielding terminal events
|
||||
# This ensures the storage/syncing happens even if the consumer breaks after receiving the event
|
||||
if (
|
||||
stream_chunk.type in {"response.completed", "response.incomplete"}
|
||||
and final_response
|
||||
and failed_response is None
|
||||
):
|
||||
messages_to_store = list(
|
||||
filter(lambda x: not isinstance(x, OpenAISystemMessageParam), orchestrator.final_messages)
|
||||
)
|
||||
if store:
|
||||
# TODO: we really should work off of output_items instead of "final_messages"
|
||||
await self._store_response(
|
||||
response=final_response,
|
||||
input=all_input,
|
||||
messages=messages_to_store,
|
||||
)
|
||||
|
||||
if conversation:
|
||||
await self._sync_response_to_conversation(conversation, input, output_items)
|
||||
await self.responses_store.store_conversation_messages(conversation, messages_to_store)
|
||||
|
||||
yield stream_chunk
|
||||
|
||||
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
return await self.responses_store.delete_response_object(response_id)
|
||||
|
||||
async def _sync_response_to_conversation(
|
||||
self, conversation_id: str, input: str | list[OpenAIResponseInput] | None, output_items: list[ConversationItem]
|
||||
) -> None:
|
||||
"""Sync content and response messages to the conversation."""
|
||||
conversation_items = []
|
||||
|
||||
if isinstance(input, str):
|
||||
conversation_items.append(
|
||||
OpenAIResponseMessage(role="user", content=[OpenAIResponseInputMessageContentText(text=input)])
|
||||
)
|
||||
elif isinstance(input, list):
|
||||
conversation_items.extend(input)
|
||||
|
||||
conversation_items.extend(output_items)
|
||||
|
||||
adapter = TypeAdapter(list[ConversationItem])
|
||||
validated_items = adapter.validate_python(conversation_items)
|
||||
await self.conversations_api.add_items(conversation_id, validated_items)
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,449 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseObjectStreamResponseFileSearchCallCompleted,
|
||||
OpenAIResponseObjectStreamResponseFileSearchCallInProgress,
|
||||
OpenAIResponseObjectStreamResponseFileSearchCallSearching,
|
||||
OpenAIResponseObjectStreamResponseMcpCallCompleted,
|
||||
OpenAIResponseObjectStreamResponseMcpCallFailed,
|
||||
OpenAIResponseObjectStreamResponseMcpCallInProgress,
|
||||
OpenAIResponseObjectStreamResponseWebSearchCallCompleted,
|
||||
OpenAIResponseObjectStreamResponseWebSearchCallInProgress,
|
||||
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
|
||||
OpenAIResponseOutputMessageFileSearchToolCall,
|
||||
OpenAIResponseOutputMessageFileSearchToolCallResults,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
)
|
||||
from llama_stack.apis.common.content_types import (
|
||||
ImageContentItem,
|
||||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIImageURL,
|
||||
OpenAIToolMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.telemetry import tracing
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .types import ChatCompletionContext, ToolExecutionResult
|
||||
|
||||
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
tool_groups_api: ToolGroups,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
vector_io_api: VectorIO,
|
||||
):
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.vector_io_api = vector_io_api
|
||||
|
||||
async def execute_tool_call(
|
||||
self,
|
||||
tool_call: OpenAIChatCompletionToolCall,
|
||||
ctx: ChatCompletionContext,
|
||||
sequence_number: int,
|
||||
output_index: int,
|
||||
item_id: str,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
tool_call_id = tool_call.id
|
||||
function = tool_call.function
|
||||
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
|
||||
|
||||
if not function or not tool_call_id or not function.name:
|
||||
yield ToolExecutionResult(sequence_number=sequence_number)
|
||||
return
|
||||
|
||||
# Emit progress events for tool execution start
|
||||
async for event_result in self._emit_progress_events(
|
||||
function.name, ctx, sequence_number, output_index, item_id, mcp_tool_to_server
|
||||
):
|
||||
sequence_number = event_result.sequence_number
|
||||
yield event_result
|
||||
|
||||
# Execute the actual tool call
|
||||
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
|
||||
|
||||
# Emit completion events for tool execution
|
||||
has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
|
||||
async for event_result in self._emit_completion_events(
|
||||
function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
|
||||
):
|
||||
sequence_number = event_result.sequence_number
|
||||
yield event_result
|
||||
|
||||
# Build result messages from tool execution
|
||||
output_message, input_message = await self._build_result_messages(
|
||||
function, tool_call_id, item_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server
|
||||
)
|
||||
|
||||
# Yield the final result
|
||||
yield ToolExecutionResult(
|
||||
sequence_number=sequence_number,
|
||||
final_output_message=output_message,
|
||||
final_input_message=input_message,
|
||||
citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
|
||||
)
|
||||
|
||||
async def _execute_knowledge_search_via_vector_store(
|
||||
self,
|
||||
query: str,
|
||||
response_file_search_tool: OpenAIResponseInputToolFileSearch,
|
||||
) -> ToolInvocationResult:
|
||||
"""Execute knowledge search using vector_stores.search API with filters support."""
|
||||
search_results = []
|
||||
|
||||
# Create search tasks for all vector stores
|
||||
async def search_single_store(vector_store_id):
|
||||
try:
|
||||
search_response = await self.vector_io_api.openai_search_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
filters=response_file_search_tool.filters,
|
||||
max_num_results=response_file_search_tool.max_num_results,
|
||||
ranking_options=response_file_search_tool.ranking_options,
|
||||
rewrite_query=False,
|
||||
)
|
||||
return search_response.data
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to search vector store {vector_store_id}: {e}")
|
||||
return []
|
||||
|
||||
# Run all searches in parallel using gather
|
||||
search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids]
|
||||
all_results = await asyncio.gather(*search_tasks)
|
||||
|
||||
# Flatten results
|
||||
for results in all_results:
|
||||
search_results.extend(results)
|
||||
|
||||
content_items = []
|
||||
content_items.append(
|
||||
TextContentItem(
|
||||
text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n"
|
||||
)
|
||||
)
|
||||
|
||||
unique_files = set()
|
||||
for i, result_item in enumerate(search_results):
|
||||
chunk_text = result_item.content[0].text if result_item.content else ""
|
||||
# Get file_id from attributes if result_item.file_id is empty
|
||||
file_id = result_item.file_id or (
|
||||
result_item.attributes.get("document_id") if result_item.attributes else None
|
||||
)
|
||||
metadata_text = f"document_id: {file_id}, score: {result_item.score}"
|
||||
if result_item.attributes:
|
||||
metadata_text += f", attributes: {result_item.attributes}"
|
||||
|
||||
text_content = f"[{i + 1}] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n"
|
||||
content_items.append(TextContentItem(text=text_content))
|
||||
unique_files.add(file_id)
|
||||
|
||||
content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
|
||||
|
||||
citation_instruction = ""
|
||||
if unique_files:
|
||||
citation_instruction = (
|
||||
" Cite sources immediately at the end of sentences before punctuation, using `<|file-id|>` format (e.g., 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'). "
|
||||
"Do not add extra punctuation. Use only the file IDs provided (do not invent new ones)."
|
||||
)
|
||||
|
||||
content_items.append(
|
||||
TextContentItem(
|
||||
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.{citation_instruction}\n',
|
||||
)
|
||||
)
|
||||
|
||||
# handling missing attributes for old versions
|
||||
citation_files = {}
|
||||
for result in search_results:
|
||||
file_id = result.file_id
|
||||
if not file_id and result.attributes:
|
||||
file_id = result.attributes.get("document_id")
|
||||
|
||||
filename = result.filename
|
||||
if not filename and result.attributes:
|
||||
filename = result.attributes.get("filename")
|
||||
if not filename:
|
||||
filename = "unknown"
|
||||
|
||||
citation_files[file_id] = filename
|
||||
|
||||
return ToolInvocationResult(
|
||||
content=content_items,
|
||||
metadata={
|
||||
"document_ids": [r.file_id for r in search_results],
|
||||
"chunks": [r.content[0].text if r.content else "" for r in search_results],
|
||||
"scores": [r.score for r in search_results],
|
||||
"citation_files": citation_files,
|
||||
},
|
||||
)
|
||||
|
||||
async def _emit_progress_events(
|
||||
self,
|
||||
function_name: str,
|
||||
ctx: ChatCompletionContext,
|
||||
sequence_number: int,
|
||||
output_index: int,
|
||||
item_id: str,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
"""Emit progress events for tool execution start."""
|
||||
# Emit in_progress event based on tool type (only for tools with specific streaming events)
|
||||
progress_event = None
|
||||
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
elif function_name == "web_search":
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
elif function_name == "knowledge_search":
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
|
||||
if progress_event:
|
||||
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
|
||||
|
||||
# For web search, emit searching event
|
||||
if function_name == "web_search":
|
||||
sequence_number += 1
|
||||
searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
|
||||
|
||||
# For file search, emit searching event
|
||||
if function_name == "knowledge_search":
|
||||
sequence_number += 1
|
||||
searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
|
||||
|
||||
async def _execute_tool(
|
||||
self,
|
||||
function_name: str,
|
||||
tool_kwargs: dict,
|
||||
ctx: ChatCompletionContext,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> tuple[Exception | None, any]:
|
||||
"""Execute the tool and return error exception and result."""
|
||||
error_exc = None
|
||||
result = None
|
||||
|
||||
try:
|
||||
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool
|
||||
|
||||
mcp_tool = mcp_tool_to_server[function_name]
|
||||
attributes = {
|
||||
"server_label": mcp_tool.server_label,
|
||||
"server_url": mcp_tool.server_url,
|
||||
"tool_name": function_name,
|
||||
}
|
||||
async with tracing.span("invoke_mcp_tool", attributes):
|
||||
result = await invoke_mcp_tool(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
tool_name=function_name,
|
||||
kwargs=tool_kwargs,
|
||||
)
|
||||
elif function_name == "knowledge_search":
|
||||
response_file_search_tool = next(
|
||||
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
|
||||
None,
|
||||
)
|
||||
if response_file_search_tool:
|
||||
# Use vector_stores.search API instead of knowledge_search tool
|
||||
# to support filters and ranking_options
|
||||
query = tool_kwargs.get("query", "")
|
||||
async with tracing.span("knowledge_search", {}):
|
||||
result = await self._execute_knowledge_search_via_vector_store(
|
||||
query=query,
|
||||
response_file_search_tool=response_file_search_tool,
|
||||
)
|
||||
else:
|
||||
attributes = {
|
||||
"tool_name": function_name,
|
||||
}
|
||||
async with tracing.span("invoke_tool", attributes):
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=function_name,
|
||||
kwargs=tool_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
error_exc = e
|
||||
|
||||
return error_exc, result
|
||||
|
||||
async def _emit_completion_events(
|
||||
self,
|
||||
function_name: str,
|
||||
ctx: ChatCompletionContext,
|
||||
sequence_number: int,
|
||||
output_index: int,
|
||||
item_id: str,
|
||||
has_error: bool,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
"""Emit completion or failure events for tool execution."""
|
||||
completion_event = None
|
||||
|
||||
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||
sequence_number += 1
|
||||
if has_error:
|
||||
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
else:
|
||||
completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
elif function_name == "web_search":
|
||||
sequence_number += 1
|
||||
completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
elif function_name == "knowledge_search":
|
||||
sequence_number += 1
|
||||
completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
|
||||
if completion_event:
|
||||
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
|
||||
|
||||
async def _build_result_messages(
|
||||
self,
|
||||
function,
|
||||
tool_call_id: str,
|
||||
item_id: str,
|
||||
tool_kwargs: dict,
|
||||
ctx: ChatCompletionContext,
|
||||
error_exc: Exception | None,
|
||||
result: any,
|
||||
has_error: bool,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> tuple[any, any]:
|
||||
"""Build output and input messages from tool execution results."""
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
# Build output message
|
||||
if mcp_tool_to_server and function.name in mcp_tool_to_server:
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
)
|
||||
|
||||
message = OpenAIResponseOutputMessageMCPCall(
|
||||
id=item_id,
|
||||
arguments=function.arguments,
|
||||
name=function.name,
|
||||
server_label=mcp_tool_to_server[function.name].server_label,
|
||||
)
|
||||
if error_exc:
|
||||
message.error = str(error_exc)
|
||||
elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
|
||||
message.error = f"Error (code {result.error_code}): {result.error_message}"
|
||||
elif result and result.content:
|
||||
message.output = interleaved_content_as_str(result.content)
|
||||
else:
|
||||
if function.name == "web_search":
|
||||
message = OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
id=item_id,
|
||||
status="completed",
|
||||
)
|
||||
if has_error:
|
||||
message.status = "failed"
|
||||
elif function.name == "knowledge_search":
|
||||
message = OpenAIResponseOutputMessageFileSearchToolCall(
|
||||
id=item_id,
|
||||
queries=[tool_kwargs.get("query", "")],
|
||||
status="completed",
|
||||
)
|
||||
if result and "document_ids" in result.metadata:
|
||||
message.results = []
|
||||
for i, doc_id in enumerate(result.metadata["document_ids"]):
|
||||
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
|
||||
score = result.metadata["scores"][i] if "scores" in result.metadata else None
|
||||
message.results.append(
|
||||
OpenAIResponseOutputMessageFileSearchToolCallResults(
|
||||
file_id=doc_id,
|
||||
filename=doc_id,
|
||||
text=text,
|
||||
score=score,
|
||||
attributes={},
|
||||
)
|
||||
)
|
||||
if has_error:
|
||||
message.status = "failed"
|
||||
else:
|
||||
raise ValueError(f"Unknown tool {function.name} called")
|
||||
|
||||
# Build input message
|
||||
input_message = None
|
||||
if result and result.content:
|
||||
if isinstance(result.content, str):
|
||||
content = result.content
|
||||
elif isinstance(result.content, list):
|
||||
content = []
|
||||
for item in result.content:
|
||||
if isinstance(item, TextContentItem):
|
||||
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
|
||||
elif isinstance(item, ImageContentItem):
|
||||
if item.image.data:
|
||||
url = f"data:image;base64,{item.image.data}"
|
||||
else:
|
||||
url = item.image.url
|
||||
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
|
||||
else:
|
||||
raise ValueError(f"Unknown result content type: {type(item)}")
|
||||
content.append(part)
|
||||
else:
|
||||
raise ValueError(f"Unknown result content type: {type(result.content)}")
|
||||
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
|
||||
else:
|
||||
text = str(error_exc) if error_exc else "Tool execution failed"
|
||||
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
|
||||
|
||||
return message, input_message
|
||||
|
|
@ -0,0 +1,194 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseMCPApprovalRequest,
|
||||
OpenAIResponseMCPApprovalResponse,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseOutput,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseTool,
|
||||
OpenAIResponseToolMCP,
|
||||
)
|
||||
from llama_stack.apis.inference import OpenAIChatCompletionToolCall, OpenAIMessageParam, OpenAIResponseFormatParam
|
||||
|
||||
|
||||
class ToolExecutionResult(BaseModel):
|
||||
"""Result of streaming tool execution."""
|
||||
|
||||
stream_event: OpenAIResponseObjectStream | None = None
|
||||
sequence_number: int
|
||||
final_output_message: OpenAIResponseOutput | None = None
|
||||
final_input_message: OpenAIMessageParam | None = None
|
||||
citation_files: dict[str, str] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatCompletionResult:
|
||||
"""Result of processing streaming chat completion chunks."""
|
||||
|
||||
response_id: str
|
||||
content: list[str]
|
||||
tool_calls: dict[int, OpenAIChatCompletionToolCall]
|
||||
created: int
|
||||
model: str
|
||||
finish_reason: str
|
||||
message_item_id: str # For streaming events
|
||||
tool_call_item_ids: dict[int, str] # For streaming events
|
||||
content_part_emitted: bool # Tracking state
|
||||
|
||||
@property
|
||||
def content_text(self) -> str:
|
||||
"""Get joined content as string."""
|
||||
return "".join(self.content)
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
"""Check if there are any tool calls."""
|
||||
return bool(self.tool_calls)
|
||||
|
||||
|
||||
class ToolContext(BaseModel):
|
||||
"""Holds information about tools from this and (if relevant)
|
||||
previous response in order to facilitate reuse of previous
|
||||
listings where appropriate."""
|
||||
|
||||
# tools argument passed into current request:
|
||||
current_tools: list[OpenAIResponseInputTool]
|
||||
# reconstructed map of tool -> mcp server from previous response:
|
||||
previous_tools: dict[str, OpenAIResponseInputToolMCP]
|
||||
# reusable mcp-list-tools objects from previous response:
|
||||
previous_tool_listings: list[OpenAIResponseOutputMessageMCPListTools]
|
||||
# tool arguments from current request that still need to be processed:
|
||||
tools_to_process: list[OpenAIResponseInputTool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
current_tools: list[OpenAIResponseInputTool] | None,
|
||||
):
|
||||
super().__init__(
|
||||
current_tools=current_tools or [],
|
||||
previous_tools={},
|
||||
previous_tool_listings=[],
|
||||
tools_to_process=current_tools or [],
|
||||
)
|
||||
|
||||
def recover_tools_from_previous_response(
|
||||
self,
|
||||
previous_response: OpenAIResponseObject,
|
||||
):
|
||||
"""Determine which mcp_list_tools objects from previous response we can reuse."""
|
||||
|
||||
if self.current_tools and previous_response.tools:
|
||||
previous_tools_by_label: dict[str, OpenAIResponseToolMCP] = {}
|
||||
for tool in previous_response.tools:
|
||||
if isinstance(tool, OpenAIResponseToolMCP):
|
||||
previous_tools_by_label[tool.server_label] = tool
|
||||
# collect tool definitions which are the same in current and previous requests:
|
||||
tools_to_process = []
|
||||
matched: dict[str, OpenAIResponseInputToolMCP] = {}
|
||||
for tool in self.current_tools:
|
||||
if isinstance(tool, OpenAIResponseInputToolMCP) and tool.server_label in previous_tools_by_label:
|
||||
previous_tool = previous_tools_by_label[tool.server_label]
|
||||
if previous_tool.allowed_tools == tool.allowed_tools:
|
||||
matched[tool.server_label] = tool
|
||||
else:
|
||||
tools_to_process.append(tool)
|
||||
else:
|
||||
tools_to_process.append(tool)
|
||||
# tools that are not the same or were not previously defined need to be processed:
|
||||
self.tools_to_process = tools_to_process
|
||||
# for all matched definitions, get the mcp_list_tools objects from the previous output:
|
||||
self.previous_tool_listings = [
|
||||
obj for obj in previous_response.output if obj.type == "mcp_list_tools" and obj.server_label in matched
|
||||
]
|
||||
# reconstruct the tool to server mappings that can be reused:
|
||||
for listing in self.previous_tool_listings:
|
||||
definition = matched[listing.server_label]
|
||||
for tool in listing.tools:
|
||||
self.previous_tools[tool.name] = definition
|
||||
|
||||
def available_tools(self) -> list[OpenAIResponseTool]:
|
||||
if not self.current_tools:
|
||||
return []
|
||||
|
||||
def convert_tool(tool: OpenAIResponseInputTool) -> OpenAIResponseTool:
|
||||
if isinstance(tool, OpenAIResponseInputToolWebSearch):
|
||||
return tool
|
||||
if isinstance(tool, OpenAIResponseInputToolFileSearch):
|
||||
return tool
|
||||
if isinstance(tool, OpenAIResponseInputToolFunction):
|
||||
return tool
|
||||
if isinstance(tool, OpenAIResponseInputToolMCP):
|
||||
return OpenAIResponseToolMCP(
|
||||
server_label=tool.server_label,
|
||||
allowed_tools=tool.allowed_tools,
|
||||
)
|
||||
|
||||
return [convert_tool(tool) for tool in self.current_tools]
|
||||
|
||||
|
||||
class ChatCompletionContext(BaseModel):
|
||||
model: str
|
||||
messages: list[OpenAIMessageParam]
|
||||
response_tools: list[OpenAIResponseInputTool] | None = None
|
||||
chat_tools: list[ChatCompletionToolParam] | None = None
|
||||
temperature: float | None
|
||||
response_format: OpenAIResponseFormatParam
|
||||
tool_context: ToolContext | None
|
||||
approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
|
||||
approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
response_tools: list[OpenAIResponseInputTool] | None,
|
||||
temperature: float | None,
|
||||
response_format: OpenAIResponseFormatParam,
|
||||
tool_context: ToolContext,
|
||||
inputs: list[OpenAIResponseInput] | str,
|
||||
):
|
||||
super().__init__(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_tools=response_tools,
|
||||
temperature=temperature,
|
||||
response_format=response_format,
|
||||
tool_context=tool_context,
|
||||
)
|
||||
if not isinstance(inputs, str):
|
||||
self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]
|
||||
self.approval_responses = {
|
||||
input.approval_request_id: input for input in inputs if input.type == "mcp_approval_response"
|
||||
}
|
||||
|
||||
def approval_response(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalResponse | None:
|
||||
request = self._approval_request(tool_name, arguments)
|
||||
return self.approval_responses.get(request.id, None) if request else None
|
||||
|
||||
def _approval_request(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalRequest | None:
|
||||
for request in self.approval_requests:
|
||||
if request.name == tool_name and request.arguments == arguments:
|
||||
return request
|
||||
return None
|
||||
|
||||
def available_tools(self) -> list[OpenAIResponseTool]:
|
||||
if not self.tool_context:
|
||||
return []
|
||||
return self.tool_context.available_tools()
|
||||
|
|
@ -0,0 +1,365 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import uuid
|
||||
|
||||
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseAnnotationFileCitation,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseInputMessageContent,
|
||||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseMCPApprovalRequest,
|
||||
OpenAIResponseMCPApprovalResponse,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageContent,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseText,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIImageURL,
|
||||
OpenAIJSONSchema,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIResponseFormatParam,
|
||||
OpenAIResponseFormatText,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.safety import Safety
|
||||
|
||||
|
||||
async def convert_chat_choice_to_response_message(
|
||||
choice: OpenAIChoice,
|
||||
citation_files: dict[str, str] | None = None,
|
||||
*,
|
||||
message_id: str | None = None,
|
||||
) -> OpenAIResponseMessage:
|
||||
"""Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
|
||||
output_content = ""
|
||||
if isinstance(choice.message.content, str):
|
||||
output_content = choice.message.content
|
||||
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
|
||||
output_content = choice.message.content.text
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
|
||||
)
|
||||
|
||||
annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})
|
||||
|
||||
return OpenAIResponseMessage(
|
||||
id=message_id or f"msg_{uuid.uuid4()}",
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
|
||||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
|
||||
async def convert_response_content_to_chat_content(
|
||||
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
|
||||
) -> str | list[OpenAIChatCompletionContentPartParam]:
|
||||
"""
|
||||
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
|
||||
|
||||
The content schemas of each API look similar, but are not exactly the same.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
converted_parts = []
|
||||
for content_part in content:
|
||||
if isinstance(content_part, OpenAIResponseInputMessageContentText):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||
elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||
elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
|
||||
if content_part.image_url:
|
||||
image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
|
||||
converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
|
||||
elif isinstance(content_part, str):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context"
|
||||
)
|
||||
return converted_parts
|
||||
|
||||
|
||||
async def convert_response_input_to_chat_messages(
|
||||
input: str | list[OpenAIResponseInput],
|
||||
previous_messages: list[OpenAIMessageParam] | None = None,
|
||||
) -> list[OpenAIMessageParam]:
|
||||
"""
|
||||
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
|
||||
|
||||
:param input: The input to convert
|
||||
:param previous_messages: Optional previous messages to check for function_call references
|
||||
"""
|
||||
messages: list[OpenAIMessageParam] = []
|
||||
if isinstance(input, list):
|
||||
# extract all OpenAIResponseInputFunctionToolCallOutput items
|
||||
# so their corresponding OpenAIToolMessageParam instances can
|
||||
# be added immediately following the corresponding
|
||||
# OpenAIAssistantMessageParam
|
||||
tool_call_results = {}
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
|
||||
tool_call_results[input_item.call_id] = OpenAIToolMessageParam(
|
||||
content=input_item.output,
|
||||
tool_call_id=input_item.call_id,
|
||||
)
|
||||
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
|
||||
# skip as these have been extracted and inserted in order
|
||||
pass
|
||||
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id=input_item.call_id,
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name=input_item.name,
|
||||
arguments=input_item.arguments,
|
||||
),
|
||||
)
|
||||
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
||||
if input_item.call_id in tool_call_results:
|
||||
messages.append(tool_call_results[input_item.call_id])
|
||||
del tool_call_results[input_item.call_id]
|
||||
elif isinstance(input_item, OpenAIResponseOutputMessageMCPCall):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id=input_item.id,
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name=input_item.name,
|
||||
arguments=input_item.arguments,
|
||||
),
|
||||
)
|
||||
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
||||
messages.append(
|
||||
OpenAIToolMessageParam(
|
||||
content=input_item.output,
|
||||
tool_call_id=input_item.id,
|
||||
)
|
||||
)
|
||||
elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
|
||||
# the tool list will be handled separately
|
||||
pass
|
||||
elif isinstance(input_item, OpenAIResponseMCPApprovalRequest) or isinstance(
|
||||
input_item, OpenAIResponseMCPApprovalResponse
|
||||
):
|
||||
# these are handled by the responses impl itself and not pass through to chat completions
|
||||
pass
|
||||
else:
|
||||
content = await convert_response_content_to_chat_content(input_item.content)
|
||||
message_type = await get_message_type_by_role(input_item.role)
|
||||
if message_type is None:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
|
||||
)
|
||||
# Skip user messages that duplicate the last user message in previous_messages
|
||||
# This handles cases where input includes context for function_call_outputs
|
||||
if previous_messages and input_item.role == "user":
|
||||
last_user_msg = None
|
||||
for msg in reversed(previous_messages):
|
||||
if isinstance(msg, OpenAIUserMessageParam):
|
||||
last_user_msg = msg
|
||||
break
|
||||
if last_user_msg:
|
||||
last_user_content = getattr(last_user_msg, "content", None)
|
||||
if last_user_content == content:
|
||||
continue # Skip duplicate user message
|
||||
messages.append(message_type(content=content))
|
||||
if len(tool_call_results):
|
||||
# Check if unpaired function_call_outputs reference function_calls from previous messages
|
||||
if previous_messages:
|
||||
previous_call_ids = _extract_tool_call_ids(previous_messages)
|
||||
for call_id in list(tool_call_results.keys()):
|
||||
if call_id in previous_call_ids:
|
||||
# Valid: this output references a call from previous messages
|
||||
# Add the tool message
|
||||
messages.append(tool_call_results[call_id])
|
||||
del tool_call_results[call_id]
|
||||
|
||||
# If still have unpaired outputs, error
|
||||
if len(tool_call_results):
|
||||
raise ValueError(
|
||||
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
|
||||
)
|
||||
else:
|
||||
messages.append(OpenAIUserMessageParam(content=input))
|
||||
return messages
|
||||
|
||||
|
||||
def _extract_tool_call_ids(messages: list[OpenAIMessageParam]) -> set[str]:
|
||||
"""Extract all tool_call IDs from messages."""
|
||||
call_ids = set()
|
||||
for msg in messages:
|
||||
if isinstance(msg, OpenAIAssistantMessageParam):
|
||||
tool_calls = getattr(msg, "tool_calls", None)
|
||||
if tool_calls:
|
||||
for tool_call in tool_calls:
|
||||
# tool_call is a Pydantic model, use attribute access
|
||||
call_ids.add(tool_call.id)
|
||||
return call_ids
|
||||
|
||||
|
||||
async def convert_response_text_to_chat_response_format(
|
||||
text: OpenAIResponseText,
|
||||
) -> OpenAIResponseFormatParam:
|
||||
"""
|
||||
Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
|
||||
"""
|
||||
if not text.format or text.format["type"] == "text":
|
||||
return OpenAIResponseFormatText(type="text")
|
||||
if text.format["type"] == "json_object":
|
||||
return OpenAIResponseFormatJSONObject()
|
||||
if text.format["type"] == "json_schema":
|
||||
return OpenAIResponseFormatJSONSchema(
|
||||
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
|
||||
)
|
||||
raise ValueError(f"Unsupported text format: {text.format}")
|
||||
|
||||
|
||||
async def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None:
|
||||
"""Get the appropriate OpenAI message parameter type for a given role."""
|
||||
role_to_type = {
|
||||
"user": OpenAIUserMessageParam,
|
||||
"system": OpenAISystemMessageParam,
|
||||
"assistant": OpenAIAssistantMessageParam,
|
||||
"developer": OpenAIDeveloperMessageParam,
|
||||
}
|
||||
return role_to_type.get(role)
|
||||
|
||||
|
||||
def _extract_citations_from_text(
|
||||
text: str, citation_files: dict[str, str]
|
||||
) -> tuple[list[OpenAIResponseAnnotationFileCitation], str]:
|
||||
"""Extract citation markers from text and create annotations
|
||||
|
||||
Args:
|
||||
text: The text containing citation markers like [file-Cn3MSNn72ENTiiq11Qda4A]
|
||||
citation_files: Dictionary mapping file_id to filename
|
||||
|
||||
Returns:
|
||||
Tuple of (annotations_list, clean_text_without_markers)
|
||||
"""
|
||||
file_id_regex = re.compile(r"<\|(?P<file_id>file-[A-Za-z0-9_-]+)\|>")
|
||||
|
||||
annotations = []
|
||||
parts = []
|
||||
total_len = 0
|
||||
last_end = 0
|
||||
|
||||
for m in file_id_regex.finditer(text):
|
||||
# segment before the marker
|
||||
prefix = text[last_end : m.start()]
|
||||
|
||||
# drop one space if it exists (since marker is at sentence end)
|
||||
if prefix.endswith(" "):
|
||||
prefix = prefix[:-1]
|
||||
|
||||
parts.append(prefix)
|
||||
total_len += len(prefix)
|
||||
|
||||
fid = m.group(1)
|
||||
if fid in citation_files:
|
||||
annotations.append(
|
||||
OpenAIResponseAnnotationFileCitation(
|
||||
file_id=fid,
|
||||
filename=citation_files[fid],
|
||||
index=total_len, # index points to punctuation
|
||||
)
|
||||
)
|
||||
|
||||
last_end = m.end()
|
||||
|
||||
parts.append(text[last_end:])
|
||||
cleaned_text = "".join(parts)
|
||||
return annotations, cleaned_text
|
||||
|
||||
|
||||
def is_function_tool_call(
|
||||
tool_call: OpenAIChatCompletionToolCall,
|
||||
tools: list[OpenAIResponseInputTool],
|
||||
) -> bool:
|
||||
if not tool_call.function:
|
||||
return False
|
||||
for t in tools:
|
||||
if t.type == "function" and t.name == tool_call.function.name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
|
||||
"""Run guardrails against messages and return violation message if blocked."""
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Look up shields to get their provider_resource_id (actual model ID)
|
||||
model_ids = []
|
||||
shields_list = await safety_api.routing_table.list_shields()
|
||||
|
||||
for guardrail_id in guardrail_ids:
|
||||
matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
|
||||
if matching_shields:
|
||||
model_id = matching_shields[0].provider_resource_id
|
||||
model_ids.append(model_id)
|
||||
else:
|
||||
raise ValueError(f"No shield found with identifier '{guardrail_id}'")
|
||||
|
||||
guardrail_tasks = [safety_api.run_moderation(messages, model=model_id) for model_id in model_ids]
|
||||
responses = await asyncio.gather(*guardrail_tasks)
|
||||
|
||||
for response in responses:
|
||||
for result in response.results:
|
||||
if result.flagged:
|
||||
message = result.user_message or "Content blocked by safety guardrails"
|
||||
flagged_categories = [cat for cat, flagged in result.categories.items() if flagged]
|
||||
violation_type = result.metadata.get("violation_type", []) if result.metadata else []
|
||||
|
||||
if flagged_categories:
|
||||
message += f" (flagged for: {', '.join(flagged_categories)})"
|
||||
if violation_type:
|
||||
message += f" (violation type: {', '.join(violation_type)})"
|
||||
|
||||
return message
|
||||
|
||||
|
||||
def extract_guardrail_ids(guardrails: list | None) -> list[str]:
|
||||
"""Extract guardrail IDs from guardrails parameter, handling both string IDs and ResponseGuardrailSpec objects."""
|
||||
if not guardrails:
|
||||
return []
|
||||
|
||||
guardrail_ids = []
|
||||
for guardrail in guardrails:
|
||||
if isinstance(guardrail, str):
|
||||
guardrail_ids.append(guardrail)
|
||||
elif isinstance(guardrail, ResponseGuardrailSpec):
|
||||
guardrail_ids.append(guardrail.type)
|
||||
else:
|
||||
raise ValueError(f"Unknown guardrail format: {guardrail}, expected str or ResponseGuardrailSpec")
|
||||
|
||||
return guardrail_ids
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||
from llama_stack.core.telemetry import tracing
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
log = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
||||
class SafetyException(Exception): # noqa: N818
|
||||
def __init__(self, violation: SafetyViolation):
|
||||
self.violation = violation
|
||||
super().__init__(violation.user_message)
|
||||
|
||||
|
||||
class ShieldRunnerMixin:
|
||||
def __init__(
|
||||
self,
|
||||
safety_api: Safety,
|
||||
input_shields: list[str] | None = None,
|
||||
output_shields: list[str] | None = None,
|
||||
):
|
||||
self.safety_api = safety_api
|
||||
self.input_shields = input_shields
|
||||
self.output_shields = output_shields
|
||||
|
||||
async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
|
||||
async def run_shield_with_span(identifier: str):
|
||||
async with tracing.span(f"run_shield_{identifier}"):
|
||||
return await self.safety_api.run_shield(
|
||||
shield_id=identifier,
|
||||
messages=messages,
|
||||
params={},
|
||||
)
|
||||
|
||||
responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers])
|
||||
for identifier, response in zip(identifiers, responses, strict=False):
|
||||
if not response.violation:
|
||||
continue
|
||||
|
||||
violation = response.violation
|
||||
if violation.violation_level == ViolationLevel.ERROR:
|
||||
raise SafetyException(violation)
|
||||
elif violation.violation_level == ViolationLevel.WARN:
|
||||
log.warning(f"[Warn]{identifier} raised a warning")
|
||||
5
src/llama_stack/providers/inline/batches/__init__.py
Normal file
5
src/llama_stack/providers/inline/batches/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.core.datatypes import AccessRule, Api
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .batches import ReferenceBatchesImpl
|
||||
from .config import ReferenceBatchesImplConfig
|
||||
|
||||
__all__ = ["ReferenceBatchesImpl", "ReferenceBatchesImplConfig"]
|
||||
|
||||
|
||||
async def get_provider_impl(config: ReferenceBatchesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
|
||||
kvstore = await kvstore_impl(config.kvstore)
|
||||
inference_api: Inference | None = deps.get(Api.inference)
|
||||
files_api: Files | None = deps.get(Api.files)
|
||||
models_api: Models | None = deps.get(Api.models)
|
||||
|
||||
if inference_api is None:
|
||||
raise ValueError("Inference API is required but not provided in dependencies")
|
||||
if files_api is None:
|
||||
raise ValueError("Files API is required but not provided in dependencies")
|
||||
if models_api is None:
|
||||
raise ValueError("Models API is required but not provided in dependencies")
|
||||
|
||||
impl = ReferenceBatchesImpl(config, inference_api, files_api, models_api, kvstore)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
679
src/llama_stack/providers/inline/batches/reference/batches.py
Normal file
679
src/llama_stack/providers/inline/batches/reference/batches.py
Normal file
|
|
@ -0,0 +1,679 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import itertools
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Any, Literal
|
||||
|
||||
from openai.types.batch import BatchError, Errors
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse
|
||||
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||
from llama_stack.apis.files import Files, OpenAIFilePurpose
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
from .config import ReferenceBatchesImplConfig
|
||||
|
||||
BATCH_PREFIX = "batch:"
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AsyncBytesIO:
|
||||
"""
|
||||
Async-compatible BytesIO wrapper to allow async file-like operations.
|
||||
|
||||
We use this when uploading files to the Files API, as it expects an
|
||||
async file-like object.
|
||||
"""
|
||||
|
||||
def __init__(self, data: bytes):
|
||||
self._buffer = BytesIO(data)
|
||||
|
||||
async def read(self, n=-1):
|
||||
return self._buffer.read(n)
|
||||
|
||||
async def seek(self, pos, whence=0):
|
||||
return self._buffer.seek(pos, whence)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self._buffer.close()
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._buffer, name)
|
||||
|
||||
|
||||
class BatchRequest(BaseModel):
|
||||
line_num: int
|
||||
custom_id: str
|
||||
method: str
|
||||
url: str
|
||||
body: dict[str, Any]
|
||||
|
||||
|
||||
def convert_to_openai_message_param(msg: dict[str, Any]) -> OpenAIMessageParam:
|
||||
"""Convert a message dictionary to OpenAIMessageParam based on role."""
|
||||
role = msg.get("role")
|
||||
|
||||
if role == "user":
|
||||
return OpenAIUserMessageParam(**msg)
|
||||
elif role == "system":
|
||||
return OpenAISystemMessageParam(**msg)
|
||||
elif role == "assistant":
|
||||
return OpenAIAssistantMessageParam(**msg)
|
||||
elif role == "tool":
|
||||
return OpenAIToolMessageParam(**msg)
|
||||
elif role == "developer":
|
||||
return OpenAIDeveloperMessageParam(**msg)
|
||||
else:
|
||||
raise ValueError(f"Unknown message role: {role}")
|
||||
|
||||
|
||||
class ReferenceBatchesImpl(Batches):
|
||||
"""Reference implementation of the Batches API.
|
||||
|
||||
This implementation processes batch files by making individual requests
|
||||
to the inference API and generates output files with results.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ReferenceBatchesImplConfig,
|
||||
inference_api: Inference,
|
||||
files_api: Files,
|
||||
models_api: Models,
|
||||
kvstore: KVStore,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.kvstore = kvstore
|
||||
self.inference_api = inference_api
|
||||
self.files_api = files_api
|
||||
self.models_api = models_api
|
||||
self._processing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._batch_semaphore = asyncio.Semaphore(config.max_concurrent_batches)
|
||||
self._update_batch_lock = asyncio.Lock()
|
||||
|
||||
# this is to allow tests to disable background processing
|
||||
self.process_batches = True
|
||||
|
||||
async def initialize(self) -> None:
|
||||
# TODO: start background processing of existing tasks
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shutdown the batches provider."""
|
||||
if self._processing_tasks:
|
||||
# don't cancel tasks - just let them stop naturally on shutdown
|
||||
# cancelling would mark batches as "cancelled" in the database
|
||||
logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks")
|
||||
|
||||
# TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions
|
||||
async def create_batch(
|
||||
self,
|
||||
input_file_id: str,
|
||||
endpoint: str,
|
||||
completion_window: Literal["24h"],
|
||||
metadata: dict[str, str] | None = None,
|
||||
idempotency_key: str | None = None,
|
||||
) -> BatchObject:
|
||||
"""
|
||||
Create a new batch for processing multiple API requests.
|
||||
|
||||
This implementation provides optional idempotency: when an idempotency key
|
||||
(idempotency_key) is provided, a deterministic ID is generated based on the input
|
||||
parameters. If a batch with the same parameters already exists, it will be
|
||||
returned instead of creating a duplicate. Without an idempotency key,
|
||||
each request creates a new batch with a unique ID.
|
||||
|
||||
Args:
|
||||
input_file_id: The ID of an uploaded file containing requests for the batch.
|
||||
endpoint: The endpoint to be used for all requests in the batch.
|
||||
completion_window: The time window within which the batch should be processed.
|
||||
metadata: Optional metadata for the batch.
|
||||
idempotency_key: Optional idempotency key for enabling idempotent behavior.
|
||||
|
||||
Returns:
|
||||
The created or existing batch object.
|
||||
"""
|
||||
|
||||
# Error handling by levels -
|
||||
# 0. Input param handling, results in 40x errors before processing, e.g.
|
||||
# - Wrong completion_window
|
||||
# - Invalid metadata types
|
||||
# - Unknown endpoint
|
||||
# -> no batch created
|
||||
# 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
|
||||
# - input_file_id missing
|
||||
# - invalid json in file
|
||||
# - missing custom_id, method, url, body
|
||||
# - invalid model
|
||||
# - streaming
|
||||
# -> batch created, validation sends to failed status
|
||||
# 2. Processing errors, result in error_file_id entries, e.g.
|
||||
# - Any error returned from inference endpoint
|
||||
# -> batch created, goes to completed status
|
||||
|
||||
# TODO: set expiration time for garbage collection
|
||||
|
||||
if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
|
||||
raise ValueError(
|
||||
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
|
||||
)
|
||||
|
||||
if completion_window != "24h":
|
||||
raise ValueError(
|
||||
f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
|
||||
)
|
||||
|
||||
batch_id = f"batch_{uuid.uuid4().hex[:16]}"
|
||||
|
||||
# For idempotent requests, use the idempotency key for the batch ID
|
||||
# This ensures the same key always maps to the same batch ID,
|
||||
# allowing us to detect parameter conflicts
|
||||
if idempotency_key is not None:
|
||||
hash_input = idempotency_key.encode("utf-8")
|
||||
hash_digest = hashlib.sha256(hash_input).hexdigest()[:24]
|
||||
batch_id = f"batch_{hash_digest}"
|
||||
|
||||
try:
|
||||
existing_batch = await self.retrieve_batch(batch_id)
|
||||
|
||||
if (
|
||||
existing_batch.input_file_id != input_file_id
|
||||
or existing_batch.endpoint != endpoint
|
||||
or existing_batch.completion_window != completion_window
|
||||
or existing_batch.metadata != metadata
|
||||
):
|
||||
raise ConflictError(
|
||||
f"Idempotency key '{idempotency_key}' was previously used with different parameters. "
|
||||
"Either use a new idempotency key or ensure all parameters match the original request."
|
||||
)
|
||||
|
||||
logger.info(f"Returning existing batch with ID: {batch_id}")
|
||||
return existing_batch
|
||||
except ResourceNotFoundError:
|
||||
# Batch doesn't exist, continue with creation
|
||||
pass
|
||||
|
||||
current_time = int(time.time())
|
||||
|
||||
batch = BatchObject(
|
||||
id=batch_id,
|
||||
object="batch",
|
||||
endpoint=endpoint,
|
||||
input_file_id=input_file_id,
|
||||
completion_window=completion_window,
|
||||
status="validating",
|
||||
created_at=current_time,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
|
||||
logger.info(f"Created new batch with ID: {batch_id}")
|
||||
|
||||
if self.process_batches:
|
||||
task = asyncio.create_task(self._process_batch(batch_id))
|
||||
self._processing_tasks[batch_id] = task
|
||||
|
||||
return batch
|
||||
|
||||
async def cancel_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Cancel a batch that is in progress."""
|
||||
batch = await self.retrieve_batch(batch_id)
|
||||
|
||||
if batch.status in ["cancelled", "cancelling"]:
|
||||
return batch
|
||||
|
||||
if batch.status in ["completed", "failed", "expired"]:
|
||||
raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'")
|
||||
|
||||
await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time()))
|
||||
|
||||
if batch_id in self._processing_tasks:
|
||||
self._processing_tasks[batch_id].cancel()
|
||||
# note: task removal and status="cancelled" handled in finally block of _process_batch
|
||||
|
||||
return await self.retrieve_batch(batch_id)
|
||||
|
||||
async def list_batches(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> ListBatchesResponse:
|
||||
"""
|
||||
List all batches, eventually only for the current user.
|
||||
|
||||
With no notion of user, we return all batches.
|
||||
"""
|
||||
batch_values = await self.kvstore.values_in_range("batch:", "batch:\xff")
|
||||
|
||||
batches = []
|
||||
for batch_data in batch_values:
|
||||
if batch_data:
|
||||
batches.append(BatchObject.model_validate_json(batch_data))
|
||||
|
||||
batches.sort(key=lambda b: b.created_at, reverse=True)
|
||||
|
||||
start_idx = 0
|
||||
if after:
|
||||
for i, batch in enumerate(batches):
|
||||
if batch.id == after:
|
||||
start_idx = i + 1
|
||||
break
|
||||
|
||||
page_batches = batches[start_idx : start_idx + limit]
|
||||
has_more = (start_idx + limit) < len(batches)
|
||||
|
||||
first_id = page_batches[0].id if page_batches else None
|
||||
last_id = page_batches[-1].id if page_batches else None
|
||||
|
||||
return ListBatchesResponse(
|
||||
data=page_batches,
|
||||
first_id=first_id,
|
||||
last_id=last_id,
|
||||
has_more=has_more,
|
||||
)
|
||||
|
||||
async def retrieve_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Retrieve information about a specific batch."""
|
||||
batch_data = await self.kvstore.get(f"batch:{batch_id}")
|
||||
if not batch_data:
|
||||
raise ResourceNotFoundError(batch_id, "Batch", "batches.list()")
|
||||
|
||||
return BatchObject.model_validate_json(batch_data)
|
||||
|
||||
async def _update_batch(self, batch_id: str, **updates) -> None:
|
||||
"""Update batch fields in kvstore."""
|
||||
async with self._update_batch_lock:
|
||||
try:
|
||||
batch = await self.retrieve_batch(batch_id)
|
||||
|
||||
# batch processing is async. once cancelling, only allow "cancelled" status updates
|
||||
if batch.status == "cancelling" and updates.get("status") != "cancelled":
|
||||
logger.info(
|
||||
f"Skipping status update for cancelled batch {batch_id}: attempted {updates.get('status')}"
|
||||
)
|
||||
return
|
||||
|
||||
if "errors" in updates:
|
||||
updates["errors"] = updates["errors"].model_dump()
|
||||
|
||||
batch_dict = batch.model_dump()
|
||||
batch_dict.update(updates)
|
||||
|
||||
await self.kvstore.set(f"batch:{batch_id}", json.dumps(batch_dict))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update batch {batch_id}: {e}")
|
||||
|
||||
async def _validate_input(self, batch: BatchObject) -> tuple[list[BatchError], list[BatchRequest]]:
|
||||
"""
|
||||
Read & validate input, return errors and valid input.
|
||||
|
||||
Validation of
|
||||
- input_file_id existance
|
||||
- valid json
|
||||
- custom_id, method, url, body presence and valid
|
||||
- no streaming
|
||||
"""
|
||||
requests: list[BatchRequest] = []
|
||||
errors: list[BatchError] = []
|
||||
try:
|
||||
await self.files_api.openai_retrieve_file(batch.input_file_id)
|
||||
except Exception:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_request",
|
||||
line=None,
|
||||
message=f"Cannot find file {batch.input_file_id}.",
|
||||
param="input_file_id",
|
||||
)
|
||||
)
|
||||
return errors, requests
|
||||
|
||||
# TODO(SECURITY): do something about large files
|
||||
file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id)
|
||||
file_content = file_content_response.body.decode("utf-8")
|
||||
for line_num, line in enumerate(file_content.strip().split("\n"), 1):
|
||||
if line.strip(): # skip empty lines
|
||||
try:
|
||||
request = json.loads(line)
|
||||
|
||||
if not isinstance(request, dict):
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_request",
|
||||
line=line_num,
|
||||
message="Each line must be a JSON dictionary object",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
valid = True
|
||||
|
||||
for param, expected_type, type_string in [
|
||||
("custom_id", str, "string"),
|
||||
("method", str, "string"),
|
||||
("url", str, "string"),
|
||||
("body", dict, "JSON dictionary object"),
|
||||
]:
|
||||
if param not in request:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="missing_required_parameter",
|
||||
line=line_num,
|
||||
message=f"Missing required parameter: {param}",
|
||||
param=param,
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
elif not isinstance(request[param], expected_type):
|
||||
param_name = "URL" if param == "url" else param.capitalize()
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_request",
|
||||
line=line_num,
|
||||
message=f"{param_name} must be a {type_string}",
|
||||
param=param,
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
|
||||
if (url := request.get("url")) and isinstance(url, str) and url != batch.endpoint:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_url",
|
||||
line=line_num,
|
||||
message="URL provided for this request does not match the batch endpoint",
|
||||
param="url",
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
|
||||
if (body := request.get("body")) and isinstance(body, dict):
|
||||
if body.get("stream", False):
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="streaming_unsupported",
|
||||
line=line_num,
|
||||
message="Streaming is not supported in batch processing",
|
||||
param="body.stream",
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
|
||||
if batch.endpoint == "/v1/chat/completions":
|
||||
required_params: list[tuple[str, Any, str]] = [
|
||||
("model", str, "a string"),
|
||||
# messages is specific to /v1/chat/completions
|
||||
# we could skip validating messages here and let inference fail. however,
|
||||
# that would be a very expensive way to find out messages is wrong.
|
||||
("messages", list, "an array"), # TODO: allow messages to be a string?
|
||||
]
|
||||
elif batch.endpoint == "/v1/completions":
|
||||
required_params = [
|
||||
("model", str, "a string"),
|
||||
("prompt", str, "a string"), # TODO: allow prompt to be a list of strings??
|
||||
]
|
||||
else: # /v1/embeddings
|
||||
required_params = [
|
||||
("model", str, "a string"),
|
||||
("input", (str, list), "a string or array of strings"),
|
||||
]
|
||||
|
||||
for param, expected_type, type_string in required_params:
|
||||
if param not in body:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_request",
|
||||
line=line_num,
|
||||
message=f"{param.capitalize()} parameter is required",
|
||||
param=f"body.{param}",
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
elif not isinstance(body[param], expected_type):
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_request",
|
||||
line=line_num,
|
||||
message=f"{param.capitalize()} must be {type_string}",
|
||||
param=f"body.{param}",
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
|
||||
if "model" in body and isinstance(body["model"], str):
|
||||
try:
|
||||
await self.models_api.get_model(body["model"])
|
||||
except Exception:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="model_not_found",
|
||||
line=line_num,
|
||||
message=f"Model '{body['model']}' does not exist or is not supported",
|
||||
param="body.model",
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
|
||||
if valid:
|
||||
assert isinstance(url, str), "URL must be a string" # for mypy
|
||||
assert isinstance(body, dict), "Body must be a dictionary" # for mypy
|
||||
requests.append(
|
||||
BatchRequest(
|
||||
line_num=line_num,
|
||||
url=url,
|
||||
method=request["method"],
|
||||
custom_id=request["custom_id"],
|
||||
body=body,
|
||||
),
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_json_line",
|
||||
line=line_num,
|
||||
message="This line is not parseable as valid JSON.",
|
||||
)
|
||||
)
|
||||
|
||||
return errors, requests
|
||||
|
||||
async def _process_batch(self, batch_id: str) -> None:
|
||||
"""Background task to process a batch of requests."""
|
||||
try:
|
||||
logger.info(f"Starting batch processing for {batch_id}")
|
||||
async with self._batch_semaphore: # semaphore to limit concurrency
|
||||
logger.info(f"Acquired semaphore for batch {batch_id}")
|
||||
await self._process_batch_impl(batch_id)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Batch processing cancelled for {batch_id}")
|
||||
await self._update_batch(batch_id, status="cancelled", cancelled_at=int(time.time()))
|
||||
except Exception as e:
|
||||
logger.error(f"Batch processing failed for {batch_id}: {e}")
|
||||
await self._update_batch(
|
||||
batch_id,
|
||||
status="failed",
|
||||
failed_at=int(time.time()),
|
||||
errors=Errors(data=[BatchError(code="internal_error", message=str(e))]),
|
||||
)
|
||||
finally:
|
||||
self._processing_tasks.pop(batch_id, None)
|
||||
|
||||
async def _process_batch_impl(self, batch_id: str) -> None:
|
||||
"""Implementation of batch processing logic."""
|
||||
errors: list[BatchError] = []
|
||||
batch = await self.retrieve_batch(batch_id)
|
||||
|
||||
errors, requests = await self._validate_input(batch)
|
||||
if errors:
|
||||
await self._update_batch(batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors))
|
||||
logger.info(f"Batch validation failed for {batch_id} with {len(errors)} errors")
|
||||
return
|
||||
|
||||
logger.info(f"Processing {len(requests)} requests for batch {batch_id}")
|
||||
|
||||
total_requests = len(requests)
|
||||
await self._update_batch(
|
||||
batch_id,
|
||||
status="in_progress",
|
||||
request_counts={"total": total_requests, "completed": 0, "failed": 0},
|
||||
)
|
||||
|
||||
error_results = []
|
||||
success_results = []
|
||||
completed_count = 0
|
||||
failed_count = 0
|
||||
|
||||
for chunk in itertools.batched(requests, self.config.max_concurrent_requests_per_batch):
|
||||
# we use a TaskGroup to ensure all process-single-request tasks are canceled when process-batch is cancelled
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
chunk_tasks = [tg.create_task(self._process_single_request(batch_id, request)) for request in chunk]
|
||||
|
||||
chunk_results = await asyncio.gather(*chunk_tasks, return_exceptions=True)
|
||||
|
||||
for result in chunk_results:
|
||||
if isinstance(result, dict) and result.get("error") is not None: # error response from inference
|
||||
failed_count += 1
|
||||
error_results.append(result)
|
||||
elif isinstance(result, dict) and result.get("response") is not None: # successful inference
|
||||
completed_count += 1
|
||||
success_results.append(result)
|
||||
else: # unexpected result
|
||||
failed_count += 1
|
||||
errors.append(BatchError(code="internal_error", message=f"Unexpected result: {result}"))
|
||||
|
||||
await self._update_batch(
|
||||
batch_id,
|
||||
request_counts={"total": total_requests, "completed": completed_count, "failed": failed_count},
|
||||
)
|
||||
|
||||
if errors:
|
||||
await self._update_batch(
|
||||
batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors)
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
output_file_id = await self._create_output_file(batch_id, success_results, "success")
|
||||
await self._update_batch(batch_id, output_file_id=output_file_id)
|
||||
|
||||
error_file_id = await self._create_output_file(batch_id, error_results, "error")
|
||||
await self._update_batch(batch_id, error_file_id=error_file_id)
|
||||
|
||||
await self._update_batch(batch_id, status="completed", completed_at=int(time.time()))
|
||||
|
||||
logger.info(
|
||||
f"Batch processing completed for {batch_id}: {completed_count} completed, {failed_count} failed"
|
||||
)
|
||||
except Exception as e:
|
||||
# note: errors is empty at this point, so we don't lose anything by ignoring it
|
||||
await self._update_batch(
|
||||
batch_id,
|
||||
status="failed",
|
||||
failed_at=int(time.time()),
|
||||
errors=Errors(data=[BatchError(code="output_failed", message=str(e))]),
|
||||
)
|
||||
|
||||
async def _process_single_request(self, batch_id: str, request: BatchRequest) -> dict:
|
||||
"""Process a single request from the batch."""
|
||||
request_id = f"batch_req_{batch_id}_{request.line_num}"
|
||||
|
||||
try:
|
||||
# TODO(SECURITY): review body for security issues
|
||||
if request.url == "/v1/chat/completions":
|
||||
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
|
||||
chat_params = OpenAIChatCompletionRequestWithExtraBody(**request.body)
|
||||
chat_response = await self.inference_api.openai_chat_completion(chat_params)
|
||||
|
||||
# this is for mypy, we don't allow streaming so we'll get the right type
|
||||
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
|
||||
return {
|
||||
"id": request_id,
|
||||
"custom_id": request.custom_id,
|
||||
"response": {
|
||||
"status_code": 200,
|
||||
"request_id": request_id, # TODO: should this be different?
|
||||
"body": chat_response.model_dump_json(),
|
||||
},
|
||||
}
|
||||
elif request.url == "/v1/completions":
|
||||
completion_params = OpenAICompletionRequestWithExtraBody(**request.body)
|
||||
completion_response = await self.inference_api.openai_completion(completion_params)
|
||||
|
||||
# this is for mypy, we don't allow streaming so we'll get the right type
|
||||
assert hasattr(completion_response, "model_dump_json"), (
|
||||
"Completion response must have model_dump_json method"
|
||||
)
|
||||
return {
|
||||
"id": request_id,
|
||||
"custom_id": request.custom_id,
|
||||
"response": {
|
||||
"status_code": 200,
|
||||
"request_id": request_id,
|
||||
"body": completion_response.model_dump_json(),
|
||||
},
|
||||
}
|
||||
else: # /v1/embeddings
|
||||
embeddings_response = await self.inference_api.openai_embeddings(
|
||||
OpenAIEmbeddingsRequestWithExtraBody(**request.body)
|
||||
)
|
||||
assert hasattr(embeddings_response, "model_dump_json"), (
|
||||
"Embeddings response must have model_dump_json method"
|
||||
)
|
||||
return {
|
||||
"id": request_id,
|
||||
"custom_id": request.custom_id,
|
||||
"response": {
|
||||
"status_code": 200,
|
||||
"request_id": request_id, # TODO: should this be different?
|
||||
"body": embeddings_response.model_dump_json(),
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
|
||||
return {
|
||||
"id": request_id,
|
||||
"custom_id": request.custom_id,
|
||||
"error": {"type": "request_failed", "message": str(e)},
|
||||
}
|
||||
|
||||
async def _create_output_file(self, batch_id: str, results: list[dict], file_type: str) -> str:
|
||||
"""
|
||||
Create an output file with batch results.
|
||||
|
||||
This function filters results based on the specified file_type
|
||||
and uploads the file to the Files API.
|
||||
"""
|
||||
output_lines = [json.dumps(result) for result in results]
|
||||
|
||||
with AsyncBytesIO("\n".join(output_lines).encode("utf-8")) as file_buffer:
|
||||
file_buffer.filename = f"{batch_id}_{file_type}.jsonl"
|
||||
uploaded_file = await self.files_api.openai_upload_file(file=file_buffer, purpose=OpenAIFilePurpose.BATCH)
|
||||
return uploaded_file.id
|
||||
40
src/llama_stack/providers/inline/batches/reference/config.py
Normal file
40
src/llama_stack/providers/inline/batches/reference/config.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
|
||||
|
||||
class ReferenceBatchesImplConfig(BaseModel):
|
||||
"""Configuration for the Reference Batches implementation."""
|
||||
|
||||
kvstore: KVStoreReference = Field(
|
||||
description="Configuration for the key-value store backend.",
|
||||
)
|
||||
|
||||
max_concurrent_batches: int = Field(
|
||||
default=1,
|
||||
description="Maximum number of concurrent batches to process simultaneously.",
|
||||
ge=1,
|
||||
)
|
||||
|
||||
max_concurrent_requests_per_batch: int = Field(
|
||||
default=10,
|
||||
description="Maximum number of concurrent requests to process per batch.",
|
||||
ge=1,
|
||||
)
|
||||
|
||||
# TODO: add a max requests per second rate limiter
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict:
|
||||
return {
|
||||
"kvstore": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="batches",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
5
src/llama_stack/providers/inline/datasetio/__init__.py
Normal file
5
src/llama_stack/providers/inline/datasetio/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import LocalFSDatasetIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: LocalFSDatasetIOConfig,
|
||||
_deps: dict[str, Any],
|
||||
):
|
||||
from .datasetio import LocalFSDatasetIOImpl
|
||||
|
||||
impl = LocalFSDatasetIOImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
23
src/llama_stack/providers/inline/datasetio/localfs/config.py
Normal file
23
src/llama_stack/providers/inline/datasetio/localfs/config.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
|
||||
|
||||
class LocalFSDatasetIOConfig(BaseModel):
|
||||
kvstore: KVStoreReference
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"kvstore": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="datasetio::localfs",
|
||||
).model_dump(exclude_none=True)
|
||||
}
|
||||
113
src/llama_stack/providers/inline/datasetio/localfs/datasetio.py
Normal file
113
src/llama_stack/providers/inline/datasetio/localfs/datasetio.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.pagination import paginate_records
|
||||
|
||||
from .config import LocalFSDatasetIOConfig
|
||||
|
||||
DATASETS_PREFIX = "localfs_datasets:"
|
||||
|
||||
|
||||
class PandasDataframeDataset:
|
||||
def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dataset_def = dataset_def
|
||||
self.df = None
|
||||
|
||||
def __len__(self) -> int:
|
||||
assert self.df is not None, "Dataset not loaded. Please call .load() first"
|
||||
return len(self.df)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
assert self.df is not None, "Dataset not loaded. Please call .load() first"
|
||||
if isinstance(idx, slice):
|
||||
return self.df.iloc[idx].to_dict(orient="records")
|
||||
else:
|
||||
return self.df.iloc[idx].to_dict()
|
||||
|
||||
async def load(self) -> None:
|
||||
if self.df is not None:
|
||||
return
|
||||
|
||||
if self.dataset_def.source.type == "uri":
|
||||
self.df = await get_dataframe_from_uri(self.dataset_def.source.uri)
|
||||
elif self.dataset_def.source.type == "rows":
|
||||
import pandas
|
||||
|
||||
self.df = pandas.DataFrame(self.dataset_def.source.rows)
|
||||
else:
|
||||
raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
|
||||
|
||||
if self.df is None:
|
||||
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
|
||||
|
||||
|
||||
class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||
def __init__(self, config: LocalFSDatasetIOConfig) -> None:
|
||||
self.config = config
|
||||
# local registry for keeping track of datasets within the provider
|
||||
self.dataset_infos = {}
|
||||
self.kvstore = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
# Load existing datasets from kvstore
|
||||
start_key = DATASETS_PREFIX
|
||||
end_key = f"{DATASETS_PREFIX}\xff"
|
||||
stored_datasets = await self.kvstore.values_in_range(start_key, end_key)
|
||||
|
||||
for dataset in stored_datasets:
|
||||
dataset = Dataset.model_validate_json(dataset)
|
||||
self.dataset_infos[dataset.identifier] = dataset
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def register_dataset(
|
||||
self,
|
||||
dataset_def: Dataset,
|
||||
) -> None:
|
||||
# Store in kvstore
|
||||
key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
|
||||
await self.kvstore.set(
|
||||
key=key,
|
||||
value=dataset_def.model_dump_json(),
|
||||
)
|
||||
self.dataset_infos[dataset_def.identifier] = dataset_def
|
||||
|
||||
async def unregister_dataset(self, dataset_id: str) -> None:
|
||||
key = f"{DATASETS_PREFIX}{dataset_id}"
|
||||
await self.kvstore.delete(key=key)
|
||||
del self.dataset_infos[dataset_id]
|
||||
|
||||
async def iterrows(
|
||||
self,
|
||||
dataset_id: str,
|
||||
start_index: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> PaginatedResponse:
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
dataset_impl = PandasDataframeDataset(dataset_def)
|
||||
await dataset_impl.load()
|
||||
|
||||
records = dataset_impl.df.to_dict("records")
|
||||
return paginate_records(records, start_index, limit)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
|
||||
import pandas
|
||||
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
dataset_impl = PandasDataframeDataset(dataset_def)
|
||||
await dataset_impl.load()
|
||||
|
||||
new_rows_df = pandas.DataFrame(rows)
|
||||
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
|
||||
5
src/llama_stack/providers/inline/eval/__init__.py
Normal file
5
src/llama_stack/providers/inline/eval/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import Api
|
||||
|
||||
from .config import MetaReferenceEvalConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceEvalConfig,
|
||||
deps: dict[Api, Any],
|
||||
):
|
||||
from .eval import MetaReferenceEvalImpl
|
||||
|
||||
impl = MetaReferenceEvalImpl(
|
||||
config,
|
||||
deps[Api.datasetio],
|
||||
deps[Api.datasets],
|
||||
deps[Api.scoring],
|
||||
deps[Api.inference],
|
||||
deps[Api.agents],
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
|
||||
|
||||
class MetaReferenceEvalConfig(BaseModel):
|
||||
kvstore: KVStoreReference
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"kvstore": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="eval",
|
||||
).model_dump(exclude_none=True)
|
||||
}
|
||||
259
src/llama_stack/providers/inline/eval/meta_reference/eval.py
Normal file
259
src/llama_stack/providers/inline/eval/meta_reference/eval.py
Normal file
|
|
@ -0,0 +1,259 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from llama_stack.apis.agents import Agents, StepType
|
||||
from llama_stack.apis.benchmarks import Benchmark
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
|
||||
MEMORY_QUERY_TOOL,
|
||||
)
|
||||
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .....apis.common.job_types import Job, JobStatus
|
||||
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
|
||||
from .config import MetaReferenceEvalConfig
|
||||
|
||||
EVAL_TASKS_PREFIX = "benchmarks:"
|
||||
|
||||
|
||||
class MetaReferenceEvalImpl(
|
||||
Eval,
|
||||
BenchmarksProtocolPrivate,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
config: MetaReferenceEvalConfig,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets_api: Datasets,
|
||||
scoring_api: Scoring,
|
||||
inference_api: Inference,
|
||||
agents_api: Agents,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets_api
|
||||
self.scoring_api = scoring_api
|
||||
self.inference_api = inference_api
|
||||
self.agents_api = agents_api
|
||||
|
||||
# TODO: assume sync job, will need jobs API for async scheduling
|
||||
self.jobs = {}
|
||||
|
||||
self.benchmarks = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
# Load existing benchmarks from kvstore
|
||||
start_key = EVAL_TASKS_PREFIX
|
||||
end_key = f"{EVAL_TASKS_PREFIX}\xff"
|
||||
stored_benchmarks = await self.kvstore.values_in_range(start_key, end_key)
|
||||
|
||||
for benchmark in stored_benchmarks:
|
||||
benchmark = Benchmark.model_validate_json(benchmark)
|
||||
self.benchmarks[benchmark.identifier] = benchmark
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def register_benchmark(self, task_def: Benchmark) -> None:
|
||||
# Store in kvstore
|
||||
key = f"{EVAL_TASKS_PREFIX}{task_def.identifier}"
|
||||
await self.kvstore.set(
|
||||
key=key,
|
||||
value=task_def.model_dump_json(),
|
||||
)
|
||||
self.benchmarks[task_def.identifier] = task_def
|
||||
|
||||
async def unregister_benchmark(self, benchmark_id: str) -> None:
|
||||
if benchmark_id in self.benchmarks:
|
||||
del self.benchmarks[benchmark_id]
|
||||
|
||||
key = f"{EVAL_TASKS_PREFIX}{benchmark_id}"
|
||||
await self.kvstore.delete(key)
|
||||
|
||||
async def run_eval(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
task_def = self.benchmarks[benchmark_id]
|
||||
dataset_id = task_def.dataset_id
|
||||
scoring_functions = task_def.scoring_functions
|
||||
|
||||
# TODO (xiyan): validate dataset schema
|
||||
# dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
|
||||
all_rows = await self.datasetio_api.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
|
||||
)
|
||||
res = await self.evaluate_rows(
|
||||
benchmark_id=benchmark_id,
|
||||
input_rows=all_rows.data,
|
||||
scoring_functions=scoring_functions,
|
||||
benchmark_config=benchmark_config,
|
||||
)
|
||||
|
||||
# TODO: currently needs to wait for generation before returning
|
||||
# need job scheduler queue (ray/celery) w/ jobs api
|
||||
job_id = str(len(self.jobs))
|
||||
self.jobs[job_id] = res
|
||||
return Job(job_id=job_id, status=JobStatus.completed)
|
||||
|
||||
async def _run_agent_generation(
|
||||
self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||
) -> list[dict[str, Any]]:
|
||||
candidate = benchmark_config.eval_candidate
|
||||
create_response = await self.agents_api.create_agent(candidate.config)
|
||||
agent_id = create_response.agent_id
|
||||
|
||||
generations = []
|
||||
for i, x in tqdm(enumerate(input_rows)):
|
||||
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
|
||||
input_messages = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
|
||||
|
||||
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
|
||||
session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
|
||||
session_id = session_create_response.session_id
|
||||
|
||||
turn_request = dict(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=input_messages,
|
||||
stream=True,
|
||||
)
|
||||
turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)]
|
||||
final_event = turn_response[-1].event.payload
|
||||
|
||||
# check if there's a memory retrieval step and extract the context
|
||||
memory_rag_context = None
|
||||
for step in final_event.turn.steps:
|
||||
if step.step_type == StepType.tool_execution.value:
|
||||
for tool_response in step.tool_responses:
|
||||
if tool_response.tool_name == MEMORY_QUERY_TOOL:
|
||||
memory_rag_context = " ".join(x.text for x in tool_response.content)
|
||||
|
||||
agent_generation = {}
|
||||
agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content
|
||||
if memory_rag_context:
|
||||
agent_generation[ColumnName.context.value] = memory_rag_context
|
||||
|
||||
generations.append(agent_generation)
|
||||
|
||||
return generations
|
||||
|
||||
async def _run_model_generation(
|
||||
self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||
) -> list[dict[str, Any]]:
|
||||
candidate = benchmark_config.eval_candidate
|
||||
assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
|
||||
sampling_params = {"max_tokens": candidate.sampling_params.max_tokens}
|
||||
|
||||
generations = []
|
||||
for x in tqdm(input_rows):
|
||||
if ColumnName.completion_input.value in x:
|
||||
if candidate.sampling_params.stop:
|
||||
sampling_params["stop"] = candidate.sampling_params.stop
|
||||
|
||||
input_content = json.loads(x[ColumnName.completion_input.value])
|
||||
params = OpenAICompletionRequestWithExtraBody(
|
||||
model=candidate.model,
|
||||
prompt=input_content,
|
||||
**sampling_params,
|
||||
)
|
||||
response = await self.inference_api.openai_completion(params)
|
||||
generations.append({ColumnName.generated_answer.value: response.choices[0].text})
|
||||
elif ColumnName.chat_completion_input.value in x:
|
||||
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [
|
||||
OpenAIUserMessageParam(**x) for x in chat_completion_input_json if x["role"] == "user"
|
||||
]
|
||||
|
||||
messages = []
|
||||
if candidate.system_message:
|
||||
messages.append(candidate.system_message)
|
||||
|
||||
messages += [OpenAISystemMessageParam(**x) for x in chat_completion_input_json if x["role"] == "system"]
|
||||
|
||||
messages += input_messages
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=candidate.model,
|
||||
messages=messages,
|
||||
**sampling_params,
|
||||
)
|
||||
response = await self.inference_api.openai_chat_completion(params)
|
||||
generations.append({ColumnName.generated_answer.value: response.choices[0].message.content})
|
||||
else:
|
||||
raise ValueError("Invalid input row")
|
||||
|
||||
return generations
|
||||
|
||||
async def evaluate_rows(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
input_rows: list[dict[str, Any]],
|
||||
scoring_functions: list[str],
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
candidate = benchmark_config.eval_candidate
|
||||
if candidate.type == "agent":
|
||||
generations = await self._run_agent_generation(input_rows, benchmark_config)
|
||||
elif candidate.type == "model":
|
||||
generations = await self._run_model_generation(input_rows, benchmark_config)
|
||||
else:
|
||||
raise ValueError(f"Invalid candidate type: {candidate.type}")
|
||||
|
||||
# scoring with generated_answer
|
||||
score_input_rows = [
|
||||
input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
|
||||
]
|
||||
|
||||
if benchmark_config.scoring_params is not None:
|
||||
scoring_functions_dict = {
|
||||
scoring_fn_id: benchmark_config.scoring_params.get(scoring_fn_id, None)
|
||||
for scoring_fn_id in scoring_functions
|
||||
}
|
||||
else:
|
||||
scoring_functions_dict = dict.fromkeys(scoring_functions)
|
||||
|
||||
score_response = await self.scoring_api.score(
|
||||
input_rows=score_input_rows, scoring_functions=scoring_functions_dict
|
||||
)
|
||||
|
||||
return EvaluateResponse(generations=generations, scores=score_response.results)
|
||||
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
|
||||
if job_id in self.jobs:
|
||||
return Job(job_id=job_id, status=JobStatus.completed)
|
||||
|
||||
raise ValueError(f"Job {job_id} not found")
|
||||
|
||||
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
|
||||
raise NotImplementedError("Job cancel is not implemented yet")
|
||||
|
||||
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
|
||||
job = await self.job_status(benchmark_id, job_id)
|
||||
status = job.status
|
||||
if not status or status != JobStatus.completed:
|
||||
raise ValueError(f"Job is not completed, Status: {status.value}")
|
||||
|
||||
return self.jobs[job_id]
|
||||
20
src/llama_stack/providers/inline/files/localfs/__init__.py
Normal file
20
src/llama_stack/providers/inline/files/localfs/__init__.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import AccessRule, Api
|
||||
|
||||
from .config import LocalfsFilesImplConfig
|
||||
from .files import LocalfsFilesImpl
|
||||
|
||||
__all__ = ["LocalfsFilesImpl", "LocalfsFilesImplConfig"]
|
||||
|
||||
|
||||
async def get_provider_impl(config: LocalfsFilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
|
||||
impl = LocalfsFilesImpl(config, policy)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
31
src/llama_stack/providers/inline/files/localfs/config.py
Normal file
31
src/llama_stack/providers/inline/files/localfs/config.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.core.storage.datatypes import SqlStoreReference
|
||||
|
||||
|
||||
class LocalfsFilesImplConfig(BaseModel):
|
||||
storage_dir: str = Field(
|
||||
description="Directory to store uploaded files",
|
||||
)
|
||||
metadata_store: SqlStoreReference = Field(
|
||||
description="SQL store configuration for file metadata",
|
||||
)
|
||||
ttl_secs: int = 365 * 24 * 60 * 60 # 1 year
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
return {
|
||||
"storage_dir": "${env.FILES_STORAGE_DIR:=" + __distro_dir__ + "/files}",
|
||||
"metadata_store": SqlStoreReference(
|
||||
backend="sql_default",
|
||||
table_name="files_metadata",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
219
src/llama_stack/providers/inline/files/localfs/files.py
Normal file
219
src/llama_stack/providers/inline/files/localfs/files.py
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, File, Form, Response, UploadFile
|
||||
|
||||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.files import (
|
||||
ExpiresAfter,
|
||||
Files,
|
||||
ListOpenAIFileResponse,
|
||||
OpenAIFileDeleteResponse,
|
||||
OpenAIFileObject,
|
||||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.core.id_generation import generate_object_id
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.files.form_data import parse_expires_after
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
|
||||
from .config import LocalfsFilesImplConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="files")
|
||||
|
||||
|
||||
class LocalfsFilesImpl(Files):
|
||||
def __init__(self, config: LocalfsFilesImplConfig, policy: list[AccessRule]) -> None:
|
||||
self.config = config
|
||||
self.policy = policy
|
||||
self.sql_store: AuthorizedSqlStore | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the files provider by setting up storage directory and metadata database."""
|
||||
# Create storage directory if it doesn't exist
|
||||
storage_path = Path(self.config.storage_dir)
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Initialize SQL store for metadata
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
|
||||
await self.sql_store.create_table(
|
||||
"openai_files",
|
||||
{
|
||||
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||
"filename": ColumnType.STRING,
|
||||
"purpose": ColumnType.STRING,
|
||||
"bytes": ColumnType.INTEGER,
|
||||
"created_at": ColumnType.INTEGER,
|
||||
"expires_at": ColumnType.INTEGER,
|
||||
"file_path": ColumnType.STRING, # Path to actual file on disk
|
||||
},
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def _generate_file_id(self) -> str:
|
||||
"""Generate a unique file ID for OpenAI API."""
|
||||
return generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
|
||||
|
||||
def _get_file_path(self, file_id: str) -> Path:
|
||||
"""Get the filesystem path for a file ID."""
|
||||
return Path(self.config.storage_dir) / file_id
|
||||
|
||||
async def _lookup_file_id(self, file_id: str) -> tuple[OpenAIFileObject, Path]:
|
||||
"""Look up a OpenAIFileObject and filesystem path from its ID."""
|
||||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
if not row:
|
||||
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
|
||||
|
||||
file_path = Path(row.pop("file_path"))
|
||||
return OpenAIFileObject(**row), file_path
|
||||
|
||||
# OpenAI Files API Implementation
|
||||
async def openai_upload_file(
|
||||
self,
|
||||
file: Annotated[UploadFile, File()],
|
||||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
|
||||
) -> OpenAIFileObject:
|
||||
"""Upload a file that can be used across various endpoints."""
|
||||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
if expires_after is not None:
|
||||
logger.warning(
|
||||
f"File expiration is not supported by this provider, ignoring expires_after: {expires_after}"
|
||||
)
|
||||
|
||||
file_id = self._generate_file_id()
|
||||
file_path = self._get_file_path(file_id)
|
||||
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
created_at = int(time.time())
|
||||
expires_at = created_at + self.config.ttl_secs
|
||||
|
||||
await self.sql_store.insert(
|
||||
"openai_files",
|
||||
{
|
||||
"id": file_id,
|
||||
"filename": file.filename or "uploaded_file",
|
||||
"purpose": purpose.value,
|
||||
"bytes": file_size,
|
||||
"created_at": created_at,
|
||||
"expires_at": expires_at,
|
||||
"file_path": file_path.as_posix(),
|
||||
},
|
||||
)
|
||||
|
||||
return OpenAIFileObject(
|
||||
id=file_id,
|
||||
filename=file.filename or "uploaded_file",
|
||||
purpose=purpose,
|
||||
bytes=file_size,
|
||||
created_at=created_at,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
async def openai_list_files(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 10000,
|
||||
order: Order | None = Order.desc,
|
||||
purpose: OpenAIFilePurpose | None = None,
|
||||
) -> ListOpenAIFileResponse:
|
||||
"""Returns a list of files that belong to the user's organization."""
|
||||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
if not order:
|
||||
order = Order.desc
|
||||
|
||||
where_conditions = {}
|
||||
if purpose:
|
||||
where_conditions["purpose"] = purpose.value
|
||||
|
||||
paginated_result = await self.sql_store.fetch_all(
|
||||
table="openai_files",
|
||||
where=where_conditions if where_conditions else None,
|
||||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
files = [
|
||||
OpenAIFileObject(
|
||||
id=row["id"],
|
||||
filename=row["filename"],
|
||||
purpose=OpenAIFilePurpose(row["purpose"]),
|
||||
bytes=row["bytes"],
|
||||
created_at=row["created_at"],
|
||||
expires_at=row["expires_at"],
|
||||
)
|
||||
for row in paginated_result.data
|
||||
]
|
||||
|
||||
return ListOpenAIFileResponse(
|
||||
data=files,
|
||||
has_more=paginated_result.has_more,
|
||||
first_id=files[0].id if files else "",
|
||||
last_id=files[-1].id if files else "",
|
||||
)
|
||||
|
||||
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
|
||||
"""Returns information about a specific file."""
|
||||
file_obj, _ = await self._lookup_file_id(file_id)
|
||||
|
||||
return file_obj
|
||||
|
||||
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
|
||||
"""Delete a file."""
|
||||
# Delete physical file
|
||||
_, file_path = await self._lookup_file_id(file_id)
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
|
||||
# Delete metadata from database
|
||||
assert self.sql_store is not None, "Files provider not initialized"
|
||||
await self.sql_store.delete("openai_files", where={"id": file_id})
|
||||
|
||||
return OpenAIFileDeleteResponse(
|
||||
id=file_id,
|
||||
deleted=True,
|
||||
)
|
||||
|
||||
async def openai_retrieve_file_content(self, file_id: str) -> Response:
|
||||
"""Returns the contents of the specified file."""
|
||||
# Read file content
|
||||
file_obj, file_path = await self._lookup_file_id(file_id)
|
||||
|
||||
if not file_path.exists():
|
||||
logger.warning(f"File '{file_id}'s underlying '{file_path}' is missing, deleting metadata.")
|
||||
await self.openai_delete_file(file_id)
|
||||
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
|
||||
|
||||
# Return as binary response with appropriate content type
|
||||
return Response(
|
||||
content=file_path.read_bytes(),
|
||||
media_type="application/octet-stream",
|
||||
headers={"Content-Disposition": f'attachment; filename="{file_obj.filename}"'},
|
||||
)
|
||||
5
src/llama_stack/providers/inline/inference/__init__.py
Normal file
5
src/llama_stack/providers/inline/inference/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceInferenceConfig,
|
||||
_deps: dict[str, Any],
|
||||
):
|
||||
from .inference import MetaReferenceInferenceImpl
|
||||
|
||||
impl = MetaReferenceInferenceImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
|
||||
|
||||
def model_checkpoint_dir(model_id) -> str:
|
||||
checkpoint_dir = Path(model_local_dir(model_id))
|
||||
|
||||
paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
|
||||
if not any(p.exists() for p in paths):
|
||||
checkpoint_dir = checkpoint_dir / "original"
|
||||
|
||||
assert checkpoint_dir.exists(), (
|
||||
f"Could not find checkpoints in: {model_local_dir(model_id)}. "
|
||||
f"If you try to use the native llama model, please download the model using `llama-model download --source meta --model-id {model_id}` (see https://github.com/meta-llama/llama-models). "
|
||||
f"Otherwise, please save your model checkpoint under {model_local_dir(model_id)}"
|
||||
)
|
||||
return str(checkpoint_dir)
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from llama_stack.apis.inference import QuantizationConfig
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
|
||||
|
||||
class MetaReferenceInferenceConfig(BaseModel):
|
||||
# this is a placeholder to indicate inference model id
|
||||
# the actual inference model id is dtermined by the moddel id in the request
|
||||
# Note: you need to register the model before using it for inference
|
||||
# models in the resouce list in the run.yaml config will be registered automatically
|
||||
model: str | None = None
|
||||
torch_seed: int | None = None
|
||||
max_seq_len: int = 4096
|
||||
max_batch_size: int = 1
|
||||
model_parallel_size: int | None = None
|
||||
|
||||
# when this is False, we assume that the distributed process group is setup by someone
|
||||
# outside of this code (e.g., when run inside `torchrun`). that is useful for clients
|
||||
# (including our testing code) who might be using llama-stack as a library.
|
||||
create_distributed_process_group: bool = True
|
||||
|
||||
# By default, the implementation will look at ~/.llama/checkpoints/<model> but you
|
||||
# can override by specifying the directory explicitly
|
||||
checkpoint_dir: str | None = None
|
||||
|
||||
quantization: QuantizationConfig | None = None
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, model: str) -> str:
|
||||
permitted_models = supported_inference_models()
|
||||
descriptors = [m.descriptor() for m in permitted_models]
|
||||
repos = [m.huggingface_repo for m in permitted_models if m.huggingface_repo is not None]
|
||||
if model not in (descriptors + repos):
|
||||
model_list = "\n\t".join(repos)
|
||||
raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
model: str = "Llama3.2-3B-Instruct",
|
||||
checkpoint_dir: str = "${env.CHECKPOINT_DIR:=null}",
|
||||
quantization_type: str = "${env.QUANTIZATION_TYPE:=bf16}",
|
||||
model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:=0}",
|
||||
max_batch_size: str = "${env.MAX_BATCH_SIZE:=1}",
|
||||
max_seq_len: str = "${env.MAX_SEQ_LEN:=4096}",
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"model": model,
|
||||
"checkpoint_dir": checkpoint_dir,
|
||||
"quantization": {
|
||||
"type": quantization_type,
|
||||
},
|
||||
"model_parallel_size": model_parallel_size,
|
||||
"max_batch_size": max_batch_size,
|
||||
"max_seq_len": max_seq_len,
|
||||
}
|
||||
|
|
@ -0,0 +1,211 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
GreedySamplingStrategy,
|
||||
JsonSchemaResponseFormat,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import QuantizationMode
|
||||
from llama_stack.models.llama.llama3.generation import Llama3
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
||||
from llama_stack.models.llama.llama4.generation import Llama4
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
||||
from llama_stack.models.llama.sku_types import Model, ModelFamily
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
get_default_tool_prompt_format,
|
||||
)
|
||||
|
||||
from .common import model_checkpoint_dir
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .inference import resolve_model
|
||||
|
||||
Tokenizer = Llama4Tokenizer | Llama3Tokenizer
|
||||
|
||||
|
||||
class LogitsProcessor:
|
||||
def __init__(self, token_enforcer: TokenEnforcer):
|
||||
self.token_enforcer = token_enforcer
|
||||
self.mask: torch.Tensor | None = None
|
||||
|
||||
def __call__(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
token_sequence = tokens[0, :].tolist()
|
||||
allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
|
||||
|
||||
if self.mask is not None:
|
||||
self.mask.fill_(-math.inf)
|
||||
else:
|
||||
self.mask = torch.full_like(scores, -math.inf)
|
||||
|
||||
self.mask[:, :, allowed_tokens] = 0
|
||||
scores = scores + self.mask
|
||||
return scores
|
||||
|
||||
|
||||
def get_logits_processor(
|
||||
tokenizer: Tokenizer,
|
||||
vocab_size: int,
|
||||
response_format: ResponseFormat | None,
|
||||
) -> Optional["LogitsProcessor"]:
|
||||
if response_format is None:
|
||||
return None
|
||||
|
||||
if not isinstance(response_format, JsonSchemaResponseFormat):
|
||||
raise ValueError(f"Unsupported response format type {response_format.type}")
|
||||
|
||||
parser = JsonSchemaParser(response_format.json_schema)
|
||||
data = TokenEnforcerTokenizerData(
|
||||
_build_regular_tokens_list(tokenizer, vocab_size),
|
||||
tokenizer.decode,
|
||||
tokenizer.stop_tokens,
|
||||
)
|
||||
token_enforcer = TokenEnforcer(data, parser)
|
||||
return LogitsProcessor(token_enforcer)
|
||||
|
||||
|
||||
def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> list[tuple[int, str, bool]]:
|
||||
token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
|
||||
regular_tokens = []
|
||||
|
||||
special_token_ids = set(tokenizer.special_tokens.values())
|
||||
for token_idx in range(vocab_size):
|
||||
if token_idx in special_token_ids:
|
||||
continue
|
||||
|
||||
# We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
|
||||
decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
|
||||
decoded_regular = tokenizer.decode([token_idx])
|
||||
is_word_start_token = len(decoded_after_0) > len(decoded_regular)
|
||||
regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
|
||||
return regular_tokens
|
||||
|
||||
|
||||
def _infer_sampling_params(sampling_params: SamplingParams):
|
||||
if isinstance(sampling_params.strategy, GreedySamplingStrategy):
|
||||
temperature = 0.0
|
||||
top_p = 1.0
|
||||
elif isinstance(sampling_params.strategy, TopPSamplingStrategy):
|
||||
temperature = sampling_params.strategy.temperature or 1.0
|
||||
top_p = sampling_params.strategy.top_p or 1.0
|
||||
else:
|
||||
raise ValueError(f"Unsupported sampling strategy {sampling_params.strategy}")
|
||||
return temperature, top_p
|
||||
|
||||
|
||||
def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
|
||||
tool_config = request.tool_config
|
||||
if tool_config is not None and tool_config.tool_prompt_format is not None:
|
||||
return tool_config.tool_prompt_format
|
||||
else:
|
||||
return get_default_tool_prompt_format(request.model)
|
||||
|
||||
|
||||
class LlamaGenerator:
|
||||
def __init__(
|
||||
self,
|
||||
config: MetaReferenceInferenceConfig,
|
||||
model_id: str,
|
||||
llama_model: Model,
|
||||
):
|
||||
if config.checkpoint_dir and config.checkpoint_dir != "null":
|
||||
ckpt_dir = config.checkpoint_dir
|
||||
else:
|
||||
resolved_model = resolve_model(model_id)
|
||||
if resolved_model is None:
|
||||
# if the model is not a native llama model, get the default checkpoint_dir based on model id
|
||||
ckpt_dir = model_checkpoint_dir(model_id)
|
||||
else:
|
||||
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
|
||||
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
|
||||
|
||||
if config.quantization:
|
||||
if config.quantization.type == "fp8_mixed":
|
||||
quantization_mode = QuantizationMode.fp8_mixed
|
||||
elif config.quantization.type == "int4_mixed":
|
||||
quantization_mode = QuantizationMode.int4_mixed
|
||||
elif config.quantization.type == "bf16":
|
||||
quantization_mode = None
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization mode {config.quantization}")
|
||||
else:
|
||||
quantization_mode = None
|
||||
|
||||
cls = Llama4 if llama_model.model_family == ModelFamily.llama4 else Llama3
|
||||
self.inner_generator = cls.build(
|
||||
ckpt_dir=ckpt_dir,
|
||||
max_seq_len=config.max_seq_len,
|
||||
max_batch_size=config.max_batch_size,
|
||||
world_size=config.model_parallel_size or llama_model.pth_file_count,
|
||||
quantization_mode=quantization_mode,
|
||||
)
|
||||
|
||||
self.tokenizer = self.inner_generator.tokenizer
|
||||
self.args = self.inner_generator.args
|
||||
self.formatter = self.inner_generator.formatter
|
||||
|
||||
def completion(
|
||||
self,
|
||||
request_batch: list[CompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
first_request = request_batch[0]
|
||||
sampling_params = first_request.sampling_params or SamplingParams()
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
||||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.inner_generator.generate(
|
||||
llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(first_request.logprobs),
|
||||
echo=False,
|
||||
logits_processor=get_logits_processor(
|
||||
self.tokenizer,
|
||||
self.args.vocab_size,
|
||||
first_request.response_format,
|
||||
),
|
||||
)
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
request_batch: list[ChatCompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
first_request = request_batch[0]
|
||||
sampling_params = first_request.sampling_params or SamplingParams()
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
||||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.inner_generator.generate(
|
||||
llm_inputs=[
|
||||
self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
|
||||
for request in request_batch
|
||||
],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(first_request.logprobs),
|
||||
echo=False,
|
||||
logits_processor=get_logits_processor(
|
||||
self.tokenizer,
|
||||
self.args.vocab_size,
|
||||
first_request.response_format,
|
||||
),
|
||||
)
|
||||
|
|
@ -0,0 +1,158 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
InferenceProvider,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
||||
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.models.llama.sku_types import ModelFamily
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .generators import LlamaGenerator
|
||||
from .model_parallel import LlamaModelParallelGenerator
|
||||
|
||||
log = get_logger(__name__, category="inference")
|
||||
# there's a single model parallel process running serving the model. for now,
|
||||
# we don't support multiple concurrent requests to this process.
|
||||
SEMAPHORE = asyncio.Semaphore(1)
|
||||
|
||||
|
||||
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
|
||||
return LlamaGenerator(config, model_id, llama_model)
|
||||
|
||||
|
||||
class MetaReferenceInferenceImpl(
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
|
||||
self.config = config
|
||||
self.model_id = None
|
||||
self.llama_model = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self.config.create_distributed_process_group:
|
||||
self.generator.stop()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError("OpenAI completion not supported by meta reference provider")
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
return None
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
llama_model = (
|
||||
resolve_model(model.metadata["llama_model"])
|
||||
if "llama_model" in model.metadata
|
||||
else resolve_model(model.identifier)
|
||||
)
|
||||
if llama_model is None:
|
||||
raise ValueError(
|
||||
"Please make sure your llama_model in model metadata or model identifier is in Llama SKU list"
|
||||
)
|
||||
|
||||
self.model_registry_helper = ModelRegistryHelper(
|
||||
[
|
||||
build_hf_repo_model_entry(
|
||||
llama_model.descriptor(),
|
||||
llama_model.core_model_id.value,
|
||||
)
|
||||
],
|
||||
)
|
||||
model = await self.model_registry_helper.register_model(model)
|
||||
|
||||
if model.model_type == ModelType.embedding:
|
||||
self._load_sentence_transformer_model(model.provider_resource_id)
|
||||
|
||||
# TODO: what is this?! you can't really specify skipping via model metadata
|
||||
# kill this madness
|
||||
if "skip_load" in model.metadata and model.metadata["skip_load"]:
|
||||
return model
|
||||
|
||||
await self.load_model(model.identifier, llama_model)
|
||||
return model
|
||||
|
||||
async def load_model(self, model_id, llama_model) -> None:
|
||||
log.info(f"Loading model `{model_id}`")
|
||||
|
||||
builder_params = [self.config, model_id, llama_model]
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
self.generator = LlamaModelParallelGenerator(
|
||||
model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
|
||||
builder_fn=llama_builder_fn,
|
||||
builder_params=builder_params,
|
||||
formatter=(
|
||||
Llama4ChatFormat(Llama4Tokenizer.get_instance())
|
||||
if llama_model.model_family == ModelFamily.llama4
|
||||
else Llama3ChatFormat(Llama3Tokenizer.get_instance())
|
||||
),
|
||||
)
|
||||
self.generator.start()
|
||||
else:
|
||||
self.generator = llama_builder_fn(*builder_params)
|
||||
|
||||
self.model_id = model_id
|
||||
self.llama_model = llama_model
|
||||
|
||||
log.info("Warming up...")
|
||||
await self.openai_chat_completion(
|
||||
model=model_id,
|
||||
messages=[{"role": "user", "content": "Hi how are you?"}],
|
||||
max_tokens=20,
|
||||
)
|
||||
log.info("Warmed up!")
|
||||
|
||||
def check_model(self, request) -> None:
|
||||
if self.model_id is None or self.llama_model is None:
|
||||
raise RuntimeError(
|
||||
"No avaible model yet, please register your requested model or add your model in the resouces first"
|
||||
)
|
||||
elif request.model != self.model_id:
|
||||
raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Callable, Generator
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
||||
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from .parallel_utils import ModelParallelProcessGroup
|
||||
|
||||
|
||||
class ModelRunner:
|
||||
def __init__(self, llama):
|
||||
self.llama = llama
|
||||
|
||||
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
|
||||
def __call__(self, task: Any):
|
||||
if task[0] == "chat_completion":
|
||||
return self.llama.chat_completion(task[1])
|
||||
else:
|
||||
raise ValueError(f"Unexpected task type {task[0]}")
|
||||
|
||||
|
||||
def init_model_cb(
|
||||
builder_fn: Callable,
|
||||
params: list[Any],
|
||||
):
|
||||
llama = builder_fn(*params)
|
||||
return ModelRunner(llama)
|
||||
|
||||
|
||||
class LlamaModelParallelGenerator:
|
||||
"""
|
||||
This abstraction exists so
|
||||
- we can run model parallel code without needing to run the CLIs via torchrun
|
||||
- this also enables use model parallel code within a notebook context.
|
||||
|
||||
A Context Manager is used to ensure that the model parallel process is started and stopped
|
||||
correctly. This does make the ergonomics a little awkward, because it isn't immediately
|
||||
clear at the callsite why we need to use a context manager.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_parallel_size: int,
|
||||
builder_fn: Callable,
|
||||
builder_params: list[Any],
|
||||
formatter: Llama3ChatFormat | Llama4ChatFormat,
|
||||
):
|
||||
self.model_parallel_size = model_parallel_size
|
||||
self.builder_fn = builder_fn
|
||||
self.builder_params = builder_params
|
||||
self.formatter = formatter
|
||||
|
||||
def start(self):
|
||||
self.__enter__()
|
||||
|
||||
def stop(self):
|
||||
self.__exit__(None, None, None)
|
||||
|
||||
def __enter__(self):
|
||||
self.group = ModelParallelProcessGroup(
|
||||
self.model_parallel_size,
|
||||
init_model_cb=partial(init_model_cb, self.builder_fn, self.builder_params),
|
||||
)
|
||||
self.group.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
self.group.stop()
|
||||
|
||||
def completion(
|
||||
self,
|
||||
request_batch: list[CompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
req_obj = deepcopy(request_batch)
|
||||
gen = self.group.run_inference(("completion", req_obj))
|
||||
yield from gen
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
request_batch: list[ChatCompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
req_obj = deepcopy(request_batch)
|
||||
gen = self.group.run_inference(("chat_completion", req_obj))
|
||||
yield from gen
|
||||
|
|
@ -0,0 +1,363 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import copy
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable, Generator
|
||||
from enum import Enum
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
from fairscale.nn.model_parallel.initialize import (
|
||||
get_model_parallel_group,
|
||||
get_model_parallel_rank,
|
||||
get_model_parallel_src_rank,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import GenerationResult
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class ProcessingMessageName(str, Enum):
|
||||
ready_request = "ready_request"
|
||||
ready_response = "ready_response"
|
||||
end_sentinel = "end_sentinel"
|
||||
cancel_sentinel = "cancel_sentinel"
|
||||
task_request = "task_request"
|
||||
task_response = "task_response"
|
||||
exception_response = "exception_response"
|
||||
|
||||
|
||||
class ReadyRequest(BaseModel):
|
||||
type: Literal[ProcessingMessageName.ready_request] = ProcessingMessageName.ready_request
|
||||
|
||||
|
||||
class ReadyResponse(BaseModel):
|
||||
type: Literal[ProcessingMessageName.ready_response] = ProcessingMessageName.ready_response
|
||||
|
||||
|
||||
class EndSentinel(BaseModel):
|
||||
type: Literal[ProcessingMessageName.end_sentinel] = ProcessingMessageName.end_sentinel
|
||||
|
||||
|
||||
class CancelSentinel(BaseModel):
|
||||
type: Literal[ProcessingMessageName.cancel_sentinel] = ProcessingMessageName.cancel_sentinel
|
||||
|
||||
|
||||
class TaskRequest(BaseModel):
|
||||
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
|
||||
task: tuple[
|
||||
str,
|
||||
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
|
||||
]
|
||||
|
||||
|
||||
class TaskResponse(BaseModel):
|
||||
type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
|
||||
result: list[GenerationResult]
|
||||
|
||||
|
||||
class ExceptionResponse(BaseModel):
|
||||
type: Literal[ProcessingMessageName.exception_response] = ProcessingMessageName.exception_response
|
||||
error: str
|
||||
|
||||
|
||||
ProcessingMessage = (
|
||||
ReadyRequest | ReadyResponse | EndSentinel | CancelSentinel | TaskRequest | TaskResponse | ExceptionResponse
|
||||
)
|
||||
|
||||
|
||||
class ProcessingMessageWrapper(BaseModel):
|
||||
payload: Annotated[
|
||||
ProcessingMessage,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
def mp_rank_0() -> bool:
|
||||
return bool(get_model_parallel_rank() == 0)
|
||||
|
||||
|
||||
def encode_msg(msg: ProcessingMessage) -> bytes:
|
||||
return ProcessingMessageWrapper(payload=msg).model_dump_json().encode("utf-8")
|
||||
|
||||
|
||||
def retrieve_requests(reply_socket_url: str):
|
||||
if mp_rank_0():
|
||||
context = zmq.Context()
|
||||
reply_socket = context.socket(zmq.ROUTER)
|
||||
reply_socket.connect(reply_socket_url)
|
||||
|
||||
while True:
|
||||
client_id, obj = maybe_get_work(reply_socket)
|
||||
if obj is None:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
ready_response = ReadyResponse()
|
||||
reply_socket.send_multipart([client_id, encode_msg(ready_response)])
|
||||
break
|
||||
|
||||
def send_obj(obj: ProcessingMessage):
|
||||
reply_socket.send_multipart([client_id, encode_msg(obj)])
|
||||
|
||||
while True:
|
||||
tasks: list[ProcessingMessage | None] = [None]
|
||||
if mp_rank_0():
|
||||
client_id, maybe_task_json = maybe_get_work(reply_socket)
|
||||
if maybe_task_json is not None:
|
||||
task = maybe_parse_message(maybe_task_json)
|
||||
# there is still an unknown unclean GeneratorExit happening resulting in a
|
||||
# cancel sentinel getting queued _after_ we have finished sending everything :/
|
||||
# kind of a hack this is :/
|
||||
if task is not None and not isinstance(task, CancelSentinel):
|
||||
tasks = [task]
|
||||
|
||||
torch.distributed.broadcast_object_list(
|
||||
tasks,
|
||||
src=get_model_parallel_src_rank(),
|
||||
group=get_model_parallel_group(),
|
||||
)
|
||||
|
||||
task = tasks[0]
|
||||
if task is None:
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
try:
|
||||
out = yield task
|
||||
if out is None:
|
||||
break
|
||||
|
||||
for obj in out:
|
||||
updates: list[ProcessingMessage | None] = [None]
|
||||
if mp_rank_0():
|
||||
_, update_json = maybe_get_work(reply_socket)
|
||||
update = maybe_parse_message(update_json)
|
||||
if isinstance(update, CancelSentinel):
|
||||
updates = [update]
|
||||
else:
|
||||
# only send the update if it's not cancelled otherwise the object sits in the socket
|
||||
# and gets pulled in the next request lol
|
||||
send_obj(TaskResponse(result=obj))
|
||||
|
||||
torch.distributed.broadcast_object_list(
|
||||
updates,
|
||||
src=get_model_parallel_src_rank(),
|
||||
group=get_model_parallel_group(),
|
||||
)
|
||||
if isinstance(updates[0], CancelSentinel):
|
||||
log.info("quitting generation loop because request was cancelled")
|
||||
break
|
||||
|
||||
if mp_rank_0():
|
||||
send_obj(EndSentinel())
|
||||
except Exception as e:
|
||||
log.exception("exception in generation loop")
|
||||
|
||||
if mp_rank_0():
|
||||
send_obj(ExceptionResponse(error=str(e)))
|
||||
|
||||
if mp_rank_0():
|
||||
send_obj(EndSentinel())
|
||||
|
||||
|
||||
def maybe_get_work(sock: zmq.Socket):
|
||||
message = None
|
||||
client_id = None
|
||||
try:
|
||||
client_id, obj = sock.recv_multipart(zmq.NOBLOCK)
|
||||
message = obj.decode("utf-8")
|
||||
except zmq.ZMQError as e:
|
||||
if e.errno != zmq.EAGAIN:
|
||||
raise e
|
||||
|
||||
return client_id, message
|
||||
|
||||
|
||||
def maybe_parse_message(maybe_json: str | None) -> ProcessingMessage | None:
|
||||
if maybe_json is None:
|
||||
return None
|
||||
try:
|
||||
return parse_message(maybe_json)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def parse_message(json_str: str) -> ProcessingMessage:
|
||||
data = json.loads(json_str)
|
||||
return copy.deepcopy(ProcessingMessageWrapper(**data).payload)
|
||||
|
||||
|
||||
def worker_process_entrypoint(
|
||||
reply_socket_url: str,
|
||||
init_model_cb: Callable,
|
||||
) -> None:
|
||||
model = init_model_cb()
|
||||
torch.distributed.barrier()
|
||||
time.sleep(1)
|
||||
|
||||
# run the requests co-routine which retrieves requests from the socket
|
||||
# and sends responses (we provide) back to the caller
|
||||
req_gen = retrieve_requests(reply_socket_url)
|
||||
result = None
|
||||
while True:
|
||||
try:
|
||||
task = req_gen.send(result)
|
||||
if isinstance(task, EndSentinel):
|
||||
break
|
||||
|
||||
assert isinstance(task, TaskRequest), task
|
||||
result = model(task.task)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
log.info("[debug] worker process done")
|
||||
|
||||
|
||||
def launch_dist_group(
|
||||
reply_socket_url: str,
|
||||
model_parallel_size: int,
|
||||
init_model_cb: Callable,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
|
||||
launch_config = LaunchConfig(
|
||||
max_nodes=1,
|
||||
min_nodes=1,
|
||||
nproc_per_node=model_parallel_size,
|
||||
start_method="fork",
|
||||
rdzv_backend="c10d",
|
||||
rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
|
||||
rdzv_configs={"store_type": "file", "timeout": 90},
|
||||
max_restarts=0,
|
||||
monitor_interval=1,
|
||||
run_id=str(uuid.uuid4()),
|
||||
)
|
||||
elastic_launch(launch_config, entrypoint=worker_process_entrypoint)(
|
||||
reply_socket_url,
|
||||
init_model_cb,
|
||||
)
|
||||
|
||||
|
||||
def start_model_parallel_process(
|
||||
model_parallel_size: int,
|
||||
init_model_cb: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
context = zmq.Context()
|
||||
request_socket = context.socket(zmq.DEALER)
|
||||
|
||||
# Binding the request socket to a random port
|
||||
request_socket.bind("tcp://127.0.0.1:0")
|
||||
|
||||
main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT)
|
||||
|
||||
ctx = multiprocessing.get_context("spawn")
|
||||
process = ctx.Process(
|
||||
target=launch_dist_group,
|
||||
args=(
|
||||
main_process_url,
|
||||
model_parallel_size,
|
||||
init_model_cb,
|
||||
),
|
||||
kwargs=kwargs,
|
||||
)
|
||||
process.start()
|
||||
|
||||
# wait until the model is loaded; rank 0 will send a message to indicate it's ready
|
||||
|
||||
request_socket.send(encode_msg(ReadyRequest()))
|
||||
_response = request_socket.recv()
|
||||
log.info("Loaded model...")
|
||||
|
||||
return request_socket, process
|
||||
|
||||
|
||||
class ModelParallelProcessGroup:
|
||||
def __init__(
|
||||
self,
|
||||
model_parallel_size: int,
|
||||
init_model_cb: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
self.model_parallel_size = model_parallel_size
|
||||
self.init_model_cb = init_model_cb
|
||||
self.started = False
|
||||
self.running = False
|
||||
|
||||
def start(self):
|
||||
assert not self.started, "process group already started"
|
||||
self.request_socket, self.process = start_model_parallel_process(
|
||||
self.model_parallel_size,
|
||||
self.init_model_cb,
|
||||
)
|
||||
self.started = True
|
||||
|
||||
def stop(self):
|
||||
assert self.started, "process group not started"
|
||||
if self.process.is_alive():
|
||||
self.request_socket.send(encode_msg(EndSentinel()), zmq.NOBLOCK)
|
||||
self.process.join()
|
||||
self.started = False
|
||||
|
||||
def run_inference(
|
||||
self,
|
||||
req: tuple[
|
||||
str,
|
||||
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
|
||||
],
|
||||
) -> Generator:
|
||||
assert not self.running, "inference already running"
|
||||
|
||||
self.running = True
|
||||
try:
|
||||
self.request_socket.send(encode_msg(TaskRequest(task=req)))
|
||||
while True:
|
||||
obj_json = self.request_socket.recv()
|
||||
obj = parse_message(obj_json)
|
||||
|
||||
if isinstance(obj, EndSentinel):
|
||||
break
|
||||
|
||||
if isinstance(obj, ExceptionResponse):
|
||||
log.error(f"[debug] got exception {obj.error}")
|
||||
raise Exception(obj.error)
|
||||
|
||||
if isinstance(obj, TaskResponse):
|
||||
yield obj.result
|
||||
|
||||
except GeneratorExit:
|
||||
self.request_socket.send(encode_msg(CancelSentinel()))
|
||||
while True:
|
||||
obj_json = self.request_socket.send()
|
||||
obj = parse_message(obj_json)
|
||||
if isinstance(obj, EndSentinel):
|
||||
break
|
||||
finally:
|
||||
self.running = False
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.providers.inline.inference.sentence_transformers.config import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: SentenceTransformersInferenceConfig,
|
||||
_deps: dict[str, Any],
|
||||
):
|
||||
from .sentence_transformers import SentenceTransformersInferenceImpl
|
||||
|
||||
impl = SentenceTransformersInferenceImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SentenceTransformersInferenceConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
|
||||
return {}
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
InferenceProvider,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
)
|
||||
|
||||
from .config import SentenceTransformersInferenceConfig
|
||||
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class SentenceTransformersInferenceImpl(
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
__provider_id__: str
|
||||
|
||||
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
return [
|
||||
Model(
|
||||
identifier="nomic-ai/nomic-embed-text-v1.5",
|
||||
provider_resource_id="nomic-ai/nomic-embed-text-v1.5",
|
||||
provider_id=self.__provider_id__,
|
||||
metadata={
|
||||
"embedding_dimension": 768,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
]
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
return model
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider")
|
||||
|
|
@ -0,0 +1,550 @@
|
|||
// !$*UTF8*$!
|
||||
{
|
||||
archiveVersion = 1;
|
||||
classes = {
|
||||
};
|
||||
objectVersion = 56;
|
||||
objects = {
|
||||
|
||||
/* Begin PBXBuildFile section */
|
||||
5CADC71A2CA471CC007662D2 /* LlamaStackClient in Frameworks */ = {isa = PBXBuildFile; productRef = 5CADC7192CA471CC007662D2 /* LlamaStackClient */; };
|
||||
5CAF3DD82CA485740029CD2B /* LlamaStackClient in Frameworks */ = {isa = PBXBuildFile; productRef = 5CAF3DD72CA485740029CD2B /* LlamaStackClient */; };
|
||||
5CCBC60C2CA1F04A00E958D0 /* LocalInference.h in Headers */ = {isa = PBXBuildFile; fileRef = 5CCBC60B2CA1F04A00E958D0 /* LocalInference.h */; settings = {ATTRIBUTES = (Public, ); }; };
|
||||
5CCBC6752CA1F45800E958D0 /* executorch_debug in Frameworks */ = {isa = PBXBuildFile; productRef = 5CCBC6742CA1F45800E958D0 /* executorch_debug */; };
|
||||
5CCBC6862CA1F64A00E958D0 /* LLaMARunner.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */; platformFilter = ios; };
|
||||
5CCBC6872CA1F64A00E958D0 /* LLaMARunner.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = 5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */; platformFilter = ios; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
|
||||
5CCBC68D2CA1F7A100E958D0 /* PromptTemplate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC6892CA1F7A000E958D0 /* PromptTemplate.swift */; };
|
||||
5CCBC68E2CA1F7A100E958D0 /* LocalInference.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC68A2CA1F7A000E958D0 /* LocalInference.swift */; };
|
||||
5CCBC68F2CA1F7A100E958D0 /* Parsing.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC68B2CA1F7A000E958D0 /* Parsing.swift */; };
|
||||
5CCBC6902CA1F7A100E958D0 /* SystemPrompts.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5CCBC68C2CA1F7A100E958D0 /* SystemPrompts.swift */; };
|
||||
5CCBC6932CA1F7D000E958D0 /* Stencil in Frameworks */ = {isa = PBXBuildFile; productRef = 5CCBC6922CA1F7D000E958D0 /* Stencil */; };
|
||||
/* End PBXBuildFile section */
|
||||
|
||||
/* Begin PBXContainerItemProxy section */
|
||||
5CCBC67D2CA1F63F00E958D0 /* PBXContainerItemProxy */ = {
|
||||
isa = PBXContainerItemProxy;
|
||||
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
|
||||
proxyType = 2;
|
||||
remoteGlobalIDString = 036CAF9D2BB1444500D6C2D5;
|
||||
remoteInfo = LLaMA;
|
||||
};
|
||||
5CCBC67F2CA1F63F00E958D0 /* PBXContainerItemProxy */ = {
|
||||
isa = PBXContainerItemProxy;
|
||||
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
|
||||
proxyType = 2;
|
||||
remoteGlobalIDString = 03729ED52BB1F8DE00152F2E;
|
||||
remoteInfo = LLaMARunner;
|
||||
};
|
||||
5CCBC69E2CA2036B00E958D0 /* PBXContainerItemProxy */ = {
|
||||
isa = PBXContainerItemProxy;
|
||||
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
|
||||
proxyType = 2;
|
||||
remoteGlobalIDString = 5CCBC6982CA2036A00E958D0;
|
||||
remoteInfo = LLaMAPerfBenchmark;
|
||||
};
|
||||
5CCBC6A02CA2036B00E958D0 /* PBXContainerItemProxy */ = {
|
||||
isa = PBXContainerItemProxy;
|
||||
containerPortal = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
|
||||
proxyType = 2;
|
||||
remoteGlobalIDString = 5CCBC6992CA2036A00E958D0;
|
||||
remoteInfo = LLaMAPerfBenchmarkTests;
|
||||
};
|
||||
/* End PBXContainerItemProxy section */
|
||||
|
||||
/* Begin PBXCopyFilesBuildPhase section */
|
||||
5CCBC6882CA1F64A00E958D0 /* Embed Frameworks */ = {
|
||||
isa = PBXCopyFilesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
dstPath = "";
|
||||
dstSubfolderSpec = 10;
|
||||
files = (
|
||||
5CCBC6872CA1F64A00E958D0 /* LLaMARunner.framework in Embed Frameworks */,
|
||||
);
|
||||
name = "Embed Frameworks";
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
/* End PBXCopyFilesBuildPhase section */
|
||||
|
||||
/* Begin PBXFileReference section */
|
||||
5CCBC6082CA1F04A00E958D0 /* LocalInferenceImpl.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = LocalInferenceImpl.framework; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||
5CCBC60B2CA1F04A00E958D0 /* LocalInference.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = LocalInference.h; sourceTree = "<group>"; };
|
||||
5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */ = {isa = PBXFileReference; lastKnownFileType = "wrapper.pb-project"; name = LLaMA.xcodeproj; path = "executorch/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj"; sourceTree = "<group>"; };
|
||||
5CCBC6892CA1F7A000E958D0 /* PromptTemplate.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PromptTemplate.swift; sourceTree = "<group>"; };
|
||||
5CCBC68A2CA1F7A000E958D0 /* LocalInference.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = LocalInference.swift; sourceTree = "<group>"; };
|
||||
5CCBC68B2CA1F7A000E958D0 /* Parsing.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Parsing.swift; sourceTree = "<group>"; };
|
||||
5CCBC68C2CA1F7A100E958D0 /* SystemPrompts.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SystemPrompts.swift; sourceTree = "<group>"; };
|
||||
/* End PBXFileReference section */
|
||||
|
||||
/* Begin PBXFrameworksBuildPhase section */
|
||||
5CCBC6052CA1F04A00E958D0 /* Frameworks */ = {
|
||||
isa = PBXFrameworksBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
5CADC71A2CA471CC007662D2 /* LlamaStackClient in Frameworks */,
|
||||
5CAF3DD82CA485740029CD2B /* LlamaStackClient in Frameworks */,
|
||||
5CCBC6932CA1F7D000E958D0 /* Stencil in Frameworks */,
|
||||
5CCBC6862CA1F64A00E958D0 /* LLaMARunner.framework in Frameworks */,
|
||||
5CCBC6752CA1F45800E958D0 /* executorch_debug in Frameworks */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
/* End PBXFrameworksBuildPhase section */
|
||||
|
||||
/* Begin PBXGroup section */
|
||||
5CCBC5FE2CA1F04A00E958D0 = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */,
|
||||
5CCBC60A2CA1F04A00E958D0 /* LocalInferenceImpl */,
|
||||
5CCBC6092CA1F04A00E958D0 /* Products */,
|
||||
5CCBC6852CA1F64A00E958D0 /* Frameworks */,
|
||||
);
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
5CCBC6092CA1F04A00E958D0 /* Products */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
5CCBC6082CA1F04A00E958D0 /* LocalInferenceImpl.framework */,
|
||||
);
|
||||
name = Products;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
5CCBC60A2CA1F04A00E958D0 /* LocalInferenceImpl */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
5CCBC68A2CA1F7A000E958D0 /* LocalInference.swift */,
|
||||
5CCBC68B2CA1F7A000E958D0 /* Parsing.swift */,
|
||||
5CCBC6892CA1F7A000E958D0 /* PromptTemplate.swift */,
|
||||
5CCBC68C2CA1F7A100E958D0 /* SystemPrompts.swift */,
|
||||
5CCBC60B2CA1F04A00E958D0 /* LocalInference.h */,
|
||||
);
|
||||
path = LocalInferenceImpl;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
5CCBC6772CA1F63F00E958D0 /* Products */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
5CCBC67E2CA1F63F00E958D0 /* LLaMA.app */,
|
||||
5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */,
|
||||
5CCBC69F2CA2036B00E958D0 /* LLaMAPerfBenchmark.app */,
|
||||
5CCBC6A12CA2036B00E958D0 /* LLaMAPerfBenchmarkTests.xctest */,
|
||||
);
|
||||
name = Products;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
5CCBC6852CA1F64A00E958D0 /* Frameworks */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
);
|
||||
name = Frameworks;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
/* End PBXGroup section */
|
||||
|
||||
/* Begin PBXHeadersBuildPhase section */
|
||||
5CCBC6032CA1F04A00E958D0 /* Headers */ = {
|
||||
isa = PBXHeadersBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
5CCBC60C2CA1F04A00E958D0 /* LocalInference.h in Headers */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
/* End PBXHeadersBuildPhase section */
|
||||
|
||||
/* Begin PBXNativeTarget section */
|
||||
5CCBC6072CA1F04A00E958D0 /* LocalInferenceImpl */ = {
|
||||
isa = PBXNativeTarget;
|
||||
buildConfigurationList = 5CCBC60F2CA1F04A00E958D0 /* Build configuration list for PBXNativeTarget "LocalInferenceImpl" */;
|
||||
buildPhases = (
|
||||
5CCBC6032CA1F04A00E958D0 /* Headers */,
|
||||
5CCBC6042CA1F04A00E958D0 /* Sources */,
|
||||
5CCBC6052CA1F04A00E958D0 /* Frameworks */,
|
||||
5CCBC6062CA1F04A00E958D0 /* Resources */,
|
||||
5CCBC6882CA1F64A00E958D0 /* Embed Frameworks */,
|
||||
);
|
||||
buildRules = (
|
||||
);
|
||||
dependencies = (
|
||||
);
|
||||
name = LocalInferenceImpl;
|
||||
packageProductDependencies = (
|
||||
5CCBC6742CA1F45800E958D0 /* executorch_debug */,
|
||||
5CCBC6922CA1F7D000E958D0 /* Stencil */,
|
||||
5CADC7192CA471CC007662D2 /* LlamaStackClient */,
|
||||
5CAF3DD72CA485740029CD2B /* LlamaStackClient */,
|
||||
);
|
||||
productName = LocalInferenceProvider;
|
||||
productReference = 5CCBC6082CA1F04A00E958D0 /* LocalInferenceImpl.framework */;
|
||||
productType = "com.apple.product-type.framework";
|
||||
};
|
||||
/* End PBXNativeTarget section */
|
||||
|
||||
/* Begin PBXProject section */
|
||||
5CCBC5FF2CA1F04A00E958D0 /* Project object */ = {
|
||||
isa = PBXProject;
|
||||
attributes = {
|
||||
BuildIndependentTargetsInParallel = 1;
|
||||
LastUpgradeCheck = 1540;
|
||||
TargetAttributes = {
|
||||
5CCBC6072CA1F04A00E958D0 = {
|
||||
CreatedOnToolsVersion = 15.4;
|
||||
LastSwiftMigration = 1540;
|
||||
};
|
||||
};
|
||||
};
|
||||
buildConfigurationList = 5CCBC6022CA1F04A00E958D0 /* Build configuration list for PBXProject "LocalInferenceImpl" */;
|
||||
compatibilityVersion = "Xcode 14.0";
|
||||
developmentRegion = en;
|
||||
hasScannedForEncodings = 0;
|
||||
knownRegions = (
|
||||
en,
|
||||
Base,
|
||||
);
|
||||
mainGroup = 5CCBC5FE2CA1F04A00E958D0;
|
||||
packageReferences = (
|
||||
5CCBC6732CA1F45800E958D0 /* XCRemoteSwiftPackageReference "executorch" */,
|
||||
5CCBC6912CA1F7D000E958D0 /* XCRemoteSwiftPackageReference "Stencil" */,
|
||||
5CAF3DD62CA485740029CD2B /* XCRemoteSwiftPackageReference "llama-stack-client-swift" */,
|
||||
);
|
||||
productRefGroup = 5CCBC6092CA1F04A00E958D0 /* Products */;
|
||||
projectDirPath = "";
|
||||
projectReferences = (
|
||||
{
|
||||
ProductGroup = 5CCBC6772CA1F63F00E958D0 /* Products */;
|
||||
ProjectRef = 5CCBC6762CA1F63F00E958D0 /* LLaMA.xcodeproj */;
|
||||
},
|
||||
);
|
||||
projectRoot = "";
|
||||
targets = (
|
||||
5CCBC6072CA1F04A00E958D0 /* LocalInferenceImpl */,
|
||||
);
|
||||
};
|
||||
/* End PBXProject section */
|
||||
|
||||
/* Begin PBXReferenceProxy section */
|
||||
5CCBC67E2CA1F63F00E958D0 /* LLaMA.app */ = {
|
||||
isa = PBXReferenceProxy;
|
||||
fileType = wrapper.application;
|
||||
path = LLaMA.app;
|
||||
remoteRef = 5CCBC67D2CA1F63F00E958D0 /* PBXContainerItemProxy */;
|
||||
sourceTree = BUILT_PRODUCTS_DIR;
|
||||
};
|
||||
5CCBC6802CA1F63F00E958D0 /* LLaMARunner.framework */ = {
|
||||
isa = PBXReferenceProxy;
|
||||
fileType = wrapper.framework;
|
||||
path = LLaMARunner.framework;
|
||||
remoteRef = 5CCBC67F2CA1F63F00E958D0 /* PBXContainerItemProxy */;
|
||||
sourceTree = BUILT_PRODUCTS_DIR;
|
||||
};
|
||||
5CCBC69F2CA2036B00E958D0 /* LLaMAPerfBenchmark.app */ = {
|
||||
isa = PBXReferenceProxy;
|
||||
fileType = wrapper.application;
|
||||
path = LLaMAPerfBenchmark.app;
|
||||
remoteRef = 5CCBC69E2CA2036B00E958D0 /* PBXContainerItemProxy */;
|
||||
sourceTree = BUILT_PRODUCTS_DIR;
|
||||
};
|
||||
5CCBC6A12CA2036B00E958D0 /* LLaMAPerfBenchmarkTests.xctest */ = {
|
||||
isa = PBXReferenceProxy;
|
||||
fileType = wrapper.cfbundle;
|
||||
path = LLaMAPerfBenchmarkTests.xctest;
|
||||
remoteRef = 5CCBC6A02CA2036B00E958D0 /* PBXContainerItemProxy */;
|
||||
sourceTree = BUILT_PRODUCTS_DIR;
|
||||
};
|
||||
/* End PBXReferenceProxy section */
|
||||
|
||||
/* Begin PBXResourcesBuildPhase section */
|
||||
5CCBC6062CA1F04A00E958D0 /* Resources */ = {
|
||||
isa = PBXResourcesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
/* End PBXResourcesBuildPhase section */
|
||||
|
||||
/* Begin PBXSourcesBuildPhase section */
|
||||
5CCBC6042CA1F04A00E958D0 /* Sources */ = {
|
||||
isa = PBXSourcesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
5CCBC6902CA1F7A100E958D0 /* SystemPrompts.swift in Sources */,
|
||||
5CCBC68D2CA1F7A100E958D0 /* PromptTemplate.swift in Sources */,
|
||||
5CCBC68F2CA1F7A100E958D0 /* Parsing.swift in Sources */,
|
||||
5CCBC68E2CA1F7A100E958D0 /* LocalInference.swift in Sources */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
/* End PBXSourcesBuildPhase section */
|
||||
|
||||
/* Begin XCBuildConfiguration section */
|
||||
5CCBC60D2CA1F04A00E958D0 /* Debug */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ALWAYS_SEARCH_USER_PATHS = NO;
|
||||
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
|
||||
CLANG_ANALYZER_NONNULL = YES;
|
||||
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
|
||||
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
|
||||
CLANG_ENABLE_MODULES = YES;
|
||||
CLANG_ENABLE_OBJC_ARC = YES;
|
||||
CLANG_ENABLE_OBJC_WEAK = YES;
|
||||
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
|
||||
CLANG_WARN_BOOL_CONVERSION = YES;
|
||||
CLANG_WARN_COMMA = YES;
|
||||
CLANG_WARN_CONSTANT_CONVERSION = YES;
|
||||
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
|
||||
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
|
||||
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
|
||||
CLANG_WARN_EMPTY_BODY = YES;
|
||||
CLANG_WARN_ENUM_CONVERSION = YES;
|
||||
CLANG_WARN_INFINITE_RECURSION = YES;
|
||||
CLANG_WARN_INT_CONVERSION = YES;
|
||||
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
|
||||
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
|
||||
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
|
||||
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
|
||||
CLANG_WARN_STRICT_PROTOTYPES = YES;
|
||||
CLANG_WARN_SUSPICIOUS_MOVE = YES;
|
||||
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
|
||||
CLANG_WARN_UNREACHABLE_CODE = YES;
|
||||
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
|
||||
COPY_PHASE_STRIP = NO;
|
||||
CURRENT_PROJECT_VERSION = 1;
|
||||
DEBUG_INFORMATION_FORMAT = dwarf;
|
||||
ENABLE_STRICT_OBJC_MSGSEND = YES;
|
||||
ENABLE_TESTABILITY = YES;
|
||||
ENABLE_USER_SCRIPT_SANDBOXING = YES;
|
||||
GCC_C_LANGUAGE_STANDARD = gnu17;
|
||||
GCC_DYNAMIC_NO_PIC = NO;
|
||||
GCC_NO_COMMON_BLOCKS = YES;
|
||||
GCC_OPTIMIZATION_LEVEL = 0;
|
||||
GCC_PREPROCESSOR_DEFINITIONS = (
|
||||
"DEBUG=1",
|
||||
"$(inherited)",
|
||||
);
|
||||
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
|
||||
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
|
||||
GCC_WARN_UNDECLARED_SELECTOR = YES;
|
||||
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
|
||||
GCC_WARN_UNUSED_FUNCTION = YES;
|
||||
GCC_WARN_UNUSED_VARIABLE = YES;
|
||||
IPHONEOS_DEPLOYMENT_TARGET = 17.5;
|
||||
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
|
||||
MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
|
||||
MTL_FAST_MATH = YES;
|
||||
ONLY_ACTIVE_ARCH = YES;
|
||||
SDKROOT = iphoneos;
|
||||
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
|
||||
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
|
||||
VERSIONING_SYSTEM = "apple-generic";
|
||||
VERSION_INFO_PREFIX = "";
|
||||
};
|
||||
name = Debug;
|
||||
};
|
||||
5CCBC60E2CA1F04A00E958D0 /* Release */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ALWAYS_SEARCH_USER_PATHS = NO;
|
||||
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
|
||||
CLANG_ANALYZER_NONNULL = YES;
|
||||
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
|
||||
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
|
||||
CLANG_ENABLE_MODULES = YES;
|
||||
CLANG_ENABLE_OBJC_ARC = YES;
|
||||
CLANG_ENABLE_OBJC_WEAK = YES;
|
||||
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
|
||||
CLANG_WARN_BOOL_CONVERSION = YES;
|
||||
CLANG_WARN_COMMA = YES;
|
||||
CLANG_WARN_CONSTANT_CONVERSION = YES;
|
||||
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
|
||||
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
|
||||
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
|
||||
CLANG_WARN_EMPTY_BODY = YES;
|
||||
CLANG_WARN_ENUM_CONVERSION = YES;
|
||||
CLANG_WARN_INFINITE_RECURSION = YES;
|
||||
CLANG_WARN_INT_CONVERSION = YES;
|
||||
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
|
||||
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
|
||||
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
|
||||
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
|
||||
CLANG_WARN_STRICT_PROTOTYPES = YES;
|
||||
CLANG_WARN_SUSPICIOUS_MOVE = YES;
|
||||
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
|
||||
CLANG_WARN_UNREACHABLE_CODE = YES;
|
||||
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
|
||||
COPY_PHASE_STRIP = NO;
|
||||
CURRENT_PROJECT_VERSION = 1;
|
||||
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
|
||||
ENABLE_NS_ASSERTIONS = NO;
|
||||
ENABLE_STRICT_OBJC_MSGSEND = YES;
|
||||
ENABLE_USER_SCRIPT_SANDBOXING = YES;
|
||||
GCC_C_LANGUAGE_STANDARD = gnu17;
|
||||
GCC_NO_COMMON_BLOCKS = YES;
|
||||
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
|
||||
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
|
||||
GCC_WARN_UNDECLARED_SELECTOR = YES;
|
||||
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
|
||||
GCC_WARN_UNUSED_FUNCTION = YES;
|
||||
GCC_WARN_UNUSED_VARIABLE = YES;
|
||||
IPHONEOS_DEPLOYMENT_TARGET = 17.5;
|
||||
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
|
||||
MTL_ENABLE_DEBUG_INFO = NO;
|
||||
MTL_FAST_MATH = YES;
|
||||
SDKROOT = iphoneos;
|
||||
SWIFT_COMPILATION_MODE = wholemodule;
|
||||
VALIDATE_PRODUCT = YES;
|
||||
VERSIONING_SYSTEM = "apple-generic";
|
||||
VERSION_INFO_PREFIX = "";
|
||||
};
|
||||
name = Release;
|
||||
};
|
||||
5CCBC6102CA1F04A00E958D0 /* Debug */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
BUILD_LIBRARY_FOR_DISTRIBUTION = YES;
|
||||
CLANG_ENABLE_MODULES = YES;
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
CURRENT_PROJECT_VERSION = 1;
|
||||
DEFINES_MODULE = YES;
|
||||
DYLIB_COMPATIBILITY_VERSION = 1;
|
||||
DYLIB_CURRENT_VERSION = 1;
|
||||
DYLIB_INSTALL_NAME_BASE = "@rpath";
|
||||
ENABLE_MODULE_VERIFIER = YES;
|
||||
GENERATE_INFOPLIST_FILE = YES;
|
||||
HEADER_SEARCH_PATHS = "";
|
||||
INFOPLIST_KEY_NSHumanReadableCopyright = "";
|
||||
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
|
||||
LD_RUNPATH_SEARCH_PATHS = (
|
||||
"$(inherited)",
|
||||
"@executable_path/Frameworks",
|
||||
"@loader_path/Frameworks",
|
||||
);
|
||||
MARKETING_VERSION = 1.0;
|
||||
MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++";
|
||||
MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20";
|
||||
OTHER_LDFLAGS = "";
|
||||
PRODUCT_BUNDLE_IDENTIFIER = meta.llamatsack.LocalInferenceProvider;
|
||||
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
|
||||
SKIP_INSTALL = YES;
|
||||
SWIFT_EMIT_LOC_STRINGS = YES;
|
||||
SWIFT_INSTALL_OBJC_HEADER = NO;
|
||||
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
|
||||
SWIFT_VERSION = 5.0;
|
||||
TARGETED_DEVICE_FAMILY = "1,2";
|
||||
};
|
||||
name = Debug;
|
||||
};
|
||||
5CCBC6112CA1F04A00E958D0 /* Release */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
BUILD_LIBRARY_FOR_DISTRIBUTION = YES;
|
||||
CLANG_ENABLE_MODULES = YES;
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
CURRENT_PROJECT_VERSION = 1;
|
||||
DEFINES_MODULE = YES;
|
||||
DYLIB_COMPATIBILITY_VERSION = 1;
|
||||
DYLIB_CURRENT_VERSION = 1;
|
||||
DYLIB_INSTALL_NAME_BASE = "@rpath";
|
||||
ENABLE_MODULE_VERIFIER = YES;
|
||||
GENERATE_INFOPLIST_FILE = YES;
|
||||
HEADER_SEARCH_PATHS = "";
|
||||
INFOPLIST_KEY_NSHumanReadableCopyright = "";
|
||||
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
|
||||
LD_RUNPATH_SEARCH_PATHS = (
|
||||
"$(inherited)",
|
||||
"@executable_path/Frameworks",
|
||||
"@loader_path/Frameworks",
|
||||
);
|
||||
MARKETING_VERSION = 1.0;
|
||||
MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++";
|
||||
MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20";
|
||||
OTHER_LDFLAGS = "";
|
||||
PRODUCT_BUNDLE_IDENTIFIER = meta.llamatsack.LocalInferenceProvider;
|
||||
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
|
||||
SKIP_INSTALL = YES;
|
||||
SWIFT_EMIT_LOC_STRINGS = YES;
|
||||
SWIFT_INSTALL_OBJC_HEADER = NO;
|
||||
SWIFT_VERSION = 5.0;
|
||||
TARGETED_DEVICE_FAMILY = "1,2";
|
||||
};
|
||||
name = Release;
|
||||
};
|
||||
/* End XCBuildConfiguration section */
|
||||
|
||||
/* Begin XCConfigurationList section */
|
||||
5CCBC6022CA1F04A00E958D0 /* Build configuration list for PBXProject "LocalInferenceImpl" */ = {
|
||||
isa = XCConfigurationList;
|
||||
buildConfigurations = (
|
||||
5CCBC60D2CA1F04A00E958D0 /* Debug */,
|
||||
5CCBC60E2CA1F04A00E958D0 /* Release */,
|
||||
);
|
||||
defaultConfigurationIsVisible = 0;
|
||||
defaultConfigurationName = Release;
|
||||
};
|
||||
5CCBC60F2CA1F04A00E958D0 /* Build configuration list for PBXNativeTarget "LocalInferenceImpl" */ = {
|
||||
isa = XCConfigurationList;
|
||||
buildConfigurations = (
|
||||
5CCBC6102CA1F04A00E958D0 /* Debug */,
|
||||
5CCBC6112CA1F04A00E958D0 /* Release */,
|
||||
);
|
||||
defaultConfigurationIsVisible = 0;
|
||||
defaultConfigurationName = Release;
|
||||
};
|
||||
/* End XCConfigurationList section */
|
||||
|
||||
/* Begin XCRemoteSwiftPackageReference section */
|
||||
5CAF3DD62CA485740029CD2B /* XCRemoteSwiftPackageReference "llama-stack-client-swift" */ = {
|
||||
isa = XCRemoteSwiftPackageReference;
|
||||
repositoryURL = "https://github.com/meta-llama/llama-stack-client-swift";
|
||||
requirement = {
|
||||
branch = main;
|
||||
kind = branch;
|
||||
};
|
||||
};
|
||||
5CCBC6732CA1F45800E958D0 /* XCRemoteSwiftPackageReference "executorch" */ = {
|
||||
isa = XCRemoteSwiftPackageReference;
|
||||
repositoryURL = "https://github.com/pytorch/executorch";
|
||||
requirement = {
|
||||
branch = latest;
|
||||
kind = branch;
|
||||
};
|
||||
};
|
||||
5CCBC6912CA1F7D000E958D0 /* XCRemoteSwiftPackageReference "Stencil" */ = {
|
||||
isa = XCRemoteSwiftPackageReference;
|
||||
repositoryURL = "https://github.com/stencilproject/Stencil";
|
||||
requirement = {
|
||||
kind = upToNextMajorVersion;
|
||||
minimumVersion = 0.15.1;
|
||||
};
|
||||
};
|
||||
/* End XCRemoteSwiftPackageReference section */
|
||||
|
||||
/* Begin XCSwiftPackageProductDependency section */
|
||||
5CADC7192CA471CC007662D2 /* LlamaStackClient */ = {
|
||||
isa = XCSwiftPackageProductDependency;
|
||||
productName = LlamaStackClient;
|
||||
};
|
||||
5CAF3DD72CA485740029CD2B /* LlamaStackClient */ = {
|
||||
isa = XCSwiftPackageProductDependency;
|
||||
package = 5CAF3DD62CA485740029CD2B /* XCRemoteSwiftPackageReference "llama-stack-client-swift" */;
|
||||
productName = LlamaStackClient;
|
||||
};
|
||||
5CCBC6742CA1F45800E958D0 /* executorch_debug */ = {
|
||||
isa = XCSwiftPackageProductDependency;
|
||||
package = 5CCBC6732CA1F45800E958D0 /* XCRemoteSwiftPackageReference "executorch" */;
|
||||
productName = executorch_debug;
|
||||
};
|
||||
5CCBC6922CA1F7D000E958D0 /* Stencil */ = {
|
||||
isa = XCSwiftPackageProductDependency;
|
||||
package = 5CCBC6912CA1F7D000E958D0 /* XCRemoteSwiftPackageReference "Stencil" */;
|
||||
productName = Stencil;
|
||||
};
|
||||
/* End XCSwiftPackageProductDependency section */
|
||||
};
|
||||
rootObject = 5CCBC5FF2CA1F04A00E958D0 /* Project object */;
|
||||
}
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<Workspace
|
||||
version = "1.0">
|
||||
<FileRef
|
||||
location = "self:">
|
||||
</FileRef>
|
||||
</Workspace>
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>IDEDidComputeMac32BitWarning</key>
|
||||
<true/>
|
||||
</dict>
|
||||
</plist>
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
#import <Foundation/Foundation.h>
|
||||
|
||||
//! Project version number for LocalInference.
|
||||
FOUNDATION_EXPORT double LocalInferenceVersionNumber;
|
||||
|
||||
//! Project version string for LocalInference.
|
||||
FOUNDATION_EXPORT const unsigned char LocalInferenceVersionString[];
|
||||
|
||||
// In this header, you should import all the public headers of your framework using statements like #import <LocalInference/PublicHeader.h>
|
||||
|
|
@ -0,0 +1,189 @@
|
|||
import Foundation
|
||||
|
||||
import LLaMARunner
|
||||
import LlamaStackClient
|
||||
|
||||
class RunnerHolder: ObservableObject {
|
||||
var runner: Runner?
|
||||
}
|
||||
|
||||
public class LocalInference: Inference {
|
||||
private var runnerHolder = RunnerHolder()
|
||||
private let runnerQueue: DispatchQueue
|
||||
|
||||
public init (queue: DispatchQueue) {
|
||||
runnerQueue = queue
|
||||
}
|
||||
|
||||
public func loadModel(modelPath: String, tokenizerPath: String, completion: @escaping (Result<Void, Error>) -> Void) {
|
||||
runnerHolder.runner = runnerHolder.runner ?? Runner(
|
||||
modelPath: modelPath,
|
||||
tokenizerPath: tokenizerPath
|
||||
)
|
||||
|
||||
|
||||
runnerQueue.async {
|
||||
let runner = self.runnerHolder.runner
|
||||
do {
|
||||
try runner!.load()
|
||||
completion(.success(()))
|
||||
} catch let loadError {
|
||||
print("error: " + loadError.localizedDescription)
|
||||
completion(.failure(loadError))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public func stop() {
|
||||
runnerHolder.runner?.stop()
|
||||
}
|
||||
|
||||
public func chatCompletion(request: Components.Schemas.ChatCompletionRequest) -> AsyncStream<Components.Schemas.ChatCompletionResponseStreamChunk> {
|
||||
return AsyncStream { continuation in
|
||||
let workItem = DispatchWorkItem {
|
||||
do {
|
||||
var tokens: [String] = []
|
||||
|
||||
let prompt = try encodeDialogPrompt(messages: prepareMessages(request: request))
|
||||
var stopReason: Components.Schemas.CompletionMessage.stop_reasonPayload? = nil
|
||||
var buffer = ""
|
||||
var ipython = false
|
||||
var echoDropped = false
|
||||
|
||||
try self.runnerHolder.runner?.generate(prompt, sequenceLength: 4096) { token in
|
||||
buffer += token
|
||||
|
||||
// HACK: Workaround until LlamaRunner exposes echo param
|
||||
if (!echoDropped) {
|
||||
if (buffer.hasPrefix(prompt)) {
|
||||
buffer = String(buffer.dropFirst(prompt.count))
|
||||
echoDropped = true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
tokens.append(token)
|
||||
|
||||
if !ipython && (buffer.starts(with: "<|python_tag|>") || buffer.starts(with: "[") ) {
|
||||
ipython = true
|
||||
continuation.yield(
|
||||
Components.Schemas.ChatCompletionResponseStreamChunk(
|
||||
event: Components.Schemas.ChatCompletionResponseEvent(
|
||||
event_type: .progress,
|
||||
delta: .tool_call(Components.Schemas.ToolCallDelta(
|
||||
_type: Components.Schemas.ToolCallDelta._typePayload.tool_call,
|
||||
tool_call: .case1(""),
|
||||
parse_status: Components.Schemas.ToolCallDelta.parse_statusPayload.started
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if (buffer.starts(with: "<|python_tag|>")) {
|
||||
buffer = String(buffer.dropFirst("<|python_tag|>".count))
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Non-streaming lobprobs
|
||||
|
||||
var text = ""
|
||||
if token == "<|eot_id|>" {
|
||||
stopReason = Components.Schemas.CompletionMessage.stop_reasonPayload.end_of_turn
|
||||
} else if token == "<|eom_id|>" {
|
||||
stopReason = Components.Schemas.CompletionMessage.stop_reasonPayload.end_of_message
|
||||
} else {
|
||||
text = token
|
||||
}
|
||||
|
||||
var delta: Components.Schemas.ContentDelta
|
||||
if ipython {
|
||||
delta = .tool_call(Components.Schemas.ToolCallDelta(
|
||||
_type: .tool_call,
|
||||
tool_call: .case1(text),
|
||||
parse_status: .in_progress
|
||||
))
|
||||
} else {
|
||||
delta = .text(Components.Schemas.TextDelta(
|
||||
_type: Components.Schemas.TextDelta._typePayload.text,
|
||||
text: text
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
if stopReason == nil {
|
||||
continuation.yield(
|
||||
Components.Schemas.ChatCompletionResponseStreamChunk(
|
||||
event: Components.Schemas.ChatCompletionResponseEvent(
|
||||
event_type: .progress,
|
||||
delta: delta
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if stopReason == nil {
|
||||
stopReason = Components.Schemas.CompletionMessage.stop_reasonPayload.out_of_tokens
|
||||
}
|
||||
|
||||
let message = decodeAssistantMessage(tokens: tokens.joined(), stopReason: stopReason!)
|
||||
// TODO: non-streaming support
|
||||
|
||||
let didParseToolCalls = message.tool_calls?.count ?? 0 > 0
|
||||
if ipython && !didParseToolCalls {
|
||||
continuation.yield(
|
||||
Components.Schemas.ChatCompletionResponseStreamChunk(
|
||||
event: Components.Schemas.ChatCompletionResponseEvent(
|
||||
event_type: .progress,
|
||||
delta: .tool_call(Components.Schemas.ToolCallDelta(
|
||||
_type: Components.Schemas.ToolCallDelta._typePayload.tool_call,
|
||||
tool_call: .case1(""),
|
||||
parse_status: Components.Schemas.ToolCallDelta.parse_statusPayload.failed
|
||||
)
|
||||
)
|
||||
)
|
||||
// TODO: stopReason
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
for toolCall in message.tool_calls! {
|
||||
continuation.yield(
|
||||
Components.Schemas.ChatCompletionResponseStreamChunk(
|
||||
event: Components.Schemas.ChatCompletionResponseEvent(
|
||||
event_type: .progress,
|
||||
delta: .tool_call(Components.Schemas.ToolCallDelta(
|
||||
_type: Components.Schemas.ToolCallDelta._typePayload.tool_call,
|
||||
tool_call: Components.Schemas.ToolCallDelta.tool_callPayload.ToolCall(toolCall),
|
||||
parse_status: Components.Schemas.ToolCallDelta.parse_statusPayload.succeeded
|
||||
)
|
||||
)
|
||||
)
|
||||
// TODO: stopReason
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
continuation.yield(
|
||||
Components.Schemas.ChatCompletionResponseStreamChunk(
|
||||
event: Components.Schemas.ChatCompletionResponseEvent(
|
||||
event_type: .complete,
|
||||
delta: .text(Components.Schemas.TextDelta(
|
||||
_type: Components.Schemas.TextDelta._typePayload.text,
|
||||
text: ""
|
||||
)
|
||||
)
|
||||
)
|
||||
// TODO: stopReason
|
||||
)
|
||||
)
|
||||
}
|
||||
catch (let error) {
|
||||
print("Inference error: " + error.localizedDescription)
|
||||
}
|
||||
}
|
||||
runnerQueue.async(execute: workItem)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,238 @@
|
|||
import Foundation
|
||||
|
||||
import LlamaStackClient
|
||||
|
||||
func encodeHeader(role: String) -> String {
|
||||
return "<|start_header_id|>\(role)<|end_header_id|>\n\n"
|
||||
}
|
||||
|
||||
func encodeDialogPrompt(messages: [Components.Schemas.Message]) -> String {
|
||||
var prompt = ""
|
||||
|
||||
prompt.append("<|begin_of_text|>")
|
||||
for message in messages {
|
||||
let msg = encodeMessage(message: message)
|
||||
prompt += msg
|
||||
}
|
||||
|
||||
prompt.append(encodeHeader(role: "assistant"))
|
||||
|
||||
return prompt
|
||||
}
|
||||
|
||||
func getRole(message: Components.Schemas.Message) -> String {
|
||||
switch (message) {
|
||||
case .user(let m):
|
||||
return m.role.rawValue
|
||||
case .system(let m):
|
||||
return m.role.rawValue
|
||||
case .tool(let m):
|
||||
return m.role.rawValue
|
||||
case .assistant(let m):
|
||||
return m.role.rawValue
|
||||
}
|
||||
}
|
||||
|
||||
func encodeMessage(message: Components.Schemas.Message) -> String {
|
||||
var prompt = encodeHeader(role: getRole(message: message))
|
||||
|
||||
switch (message) {
|
||||
case .assistant(let m):
|
||||
if (m.tool_calls?.count ?? 0 > 0) {
|
||||
prompt += "<|python_tag|>"
|
||||
}
|
||||
default:0
|
||||
break
|
||||
}
|
||||
|
||||
func _processContent(_ content: Any) -> String {
|
||||
func _process(_ c: Any) {
|
||||
if let str = c as? String {
|
||||
prompt += str
|
||||
}
|
||||
}
|
||||
|
||||
if let str = content as? String {
|
||||
_process(str)
|
||||
} else if let list = content as? [Any] {
|
||||
for c in list {
|
||||
_process(c)
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
switch (message) {
|
||||
case .user(let m):
|
||||
prompt += _processContent(m.content)
|
||||
case .system(let m):
|
||||
prompt += _processContent(m.content)
|
||||
case .tool(let m):
|
||||
prompt += _processContent(m.content)
|
||||
case .assistant(let m):
|
||||
prompt += _processContent(m.content)
|
||||
}
|
||||
|
||||
var eom = false
|
||||
|
||||
switch (message) {
|
||||
case .user(let m):
|
||||
switch (m.content) {
|
||||
case .case1(let c):
|
||||
prompt += _processContent(c)
|
||||
case .InterleavedContentItem(let c):
|
||||
prompt += _processContent(c)
|
||||
case .case3(let c):
|
||||
prompt += _processContent(c)
|
||||
}
|
||||
case .assistant(let m):
|
||||
// TODO: Support encoding past tool call history
|
||||
// for t in m.tool_calls {
|
||||
// _processContent(t.)
|
||||
//}
|
||||
eom = m.stop_reason == Components.Schemas.CompletionMessage.stop_reasonPayload.end_of_message
|
||||
case .system(_):
|
||||
break
|
||||
case .tool(_):
|
||||
break
|
||||
}
|
||||
|
||||
if (eom) {
|
||||
prompt += "<|eom_id|>"
|
||||
} else {
|
||||
prompt += "<|eot_id|>"
|
||||
}
|
||||
|
||||
return prompt
|
||||
}
|
||||
|
||||
func prepareMessages(request: Components.Schemas.ChatCompletionRequest) throws -> [Components.Schemas.Message] {
|
||||
var existingMessages = request.messages
|
||||
var existingSystemMessage: Components.Schemas.Message?
|
||||
// TODO: Existing system message
|
||||
|
||||
var messages: [Components.Schemas.Message] = []
|
||||
|
||||
let defaultGen = SystemDefaultGenerator()
|
||||
let defaultTemplate = defaultGen.gen()
|
||||
|
||||
var sysContent = ""
|
||||
|
||||
// TODO: Built-in tools
|
||||
|
||||
sysContent += try defaultTemplate.render()
|
||||
|
||||
messages.append(.system(Components.Schemas.SystemMessage(
|
||||
role: .system,
|
||||
content: .case1(sysContent)
|
||||
))
|
||||
)
|
||||
|
||||
if request.tools?.isEmpty == false {
|
||||
// TODO: Separate built-ins and custom tools (right now everything treated as custom)
|
||||
let toolGen = FunctionTagCustomToolGenerator()
|
||||
let toolTemplate = try toolGen.gen(customTools: request.tools!)
|
||||
let tools = try toolTemplate.render()
|
||||
messages.append(.user(Components.Schemas.UserMessage(
|
||||
role: .user,
|
||||
content: .case1(tools))
|
||||
))
|
||||
}
|
||||
|
||||
messages.append(contentsOf: existingMessages)
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
struct FunctionCall {
|
||||
let name: String
|
||||
let params: [String: Any]
|
||||
}
|
||||
|
||||
public func maybeExtractCustomToolCalls(input: String) -> [Components.Schemas.ToolCall] {
|
||||
guard input.hasPrefix("[") && input.hasSuffix("]") else {
|
||||
return []
|
||||
}
|
||||
|
||||
do {
|
||||
let trimmed = input.trimmingCharacters(in: CharacterSet(charactersIn: "[]"))
|
||||
let calls = trimmed.components(separatedBy: "),").map { $0.hasSuffix(")") ? $0 : $0 + ")" }
|
||||
|
||||
var result: [Components.Schemas.ToolCall] = []
|
||||
|
||||
for call in calls {
|
||||
guard let nameEndIndex = call.firstIndex(of: "("),
|
||||
let paramsStartIndex = call.firstIndex(of: "{"),
|
||||
let paramsEndIndex = call.lastIndex(of: "}") else {
|
||||
return []
|
||||
}
|
||||
|
||||
let name = String(call[..<nameEndIndex]).trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
let paramsString = String(call[paramsStartIndex...paramsEndIndex])
|
||||
|
||||
guard let data = paramsString.data(using: .utf8),
|
||||
let params = try? JSONSerialization.jsonObject(with: data, options: []) as? [String: Any] else {
|
||||
return []
|
||||
}
|
||||
|
||||
var props: [String : Components.Schemas.ToolCall.argumentsPayload.additionalPropertiesPayload] = [:]
|
||||
for (param_name, param) in params {
|
||||
switch (param) {
|
||||
case let value as String:
|
||||
props[param_name] = .case1(value)
|
||||
case let value as Int:
|
||||
props[param_name] = .case2(value)
|
||||
case let value as Double:
|
||||
props[param_name] = .case3(value)
|
||||
case let value as Bool:
|
||||
props[param_name] = .case4(value)
|
||||
default:
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
result.append(
|
||||
Components.Schemas.ToolCall(
|
||||
call_id: UUID().uuidString,
|
||||
tool_name: .case2(name), // custom_tool
|
||||
arguments: .init(additionalProperties: props)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
return result.isEmpty ? [] : result
|
||||
} catch {
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
func decodeAssistantMessage(tokens: String, stopReason: Components.Schemas.CompletionMessage.stop_reasonPayload) -> Components.Schemas.CompletionMessage {
|
||||
var content = tokens
|
||||
|
||||
let roles = ["user", "system", "assistant"]
|
||||
for role in roles {
|
||||
let headerStr = encodeHeader(role: role)
|
||||
if content.hasPrefix(headerStr) {
|
||||
content = String(content.dropFirst(encodeHeader(role: role).count))
|
||||
}
|
||||
}
|
||||
|
||||
if content.hasPrefix("<|python_tag|>") {
|
||||
content = String(content.dropFirst("<|python_tag|>".count))
|
||||
}
|
||||
|
||||
|
||||
if content.hasSuffix("<|eot_id|>") {
|
||||
content = String(content.dropLast("<|eot_id|>".count))
|
||||
} else {
|
||||
content = String(content.dropLast("<|eom_id|>".count))
|
||||
}
|
||||
|
||||
return Components.Schemas.CompletionMessage(
|
||||
role: .assistant,
|
||||
content: .case1(content),
|
||||
stop_reason: stopReason,
|
||||
tool_calls: maybeExtractCustomToolCalls(input: content)
|
||||
)
|
||||
}
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
import Foundation
|
||||
import Stencil
|
||||
|
||||
public struct PromptTemplate {
|
||||
let template: String
|
||||
let data: [String: Any]
|
||||
|
||||
public func render() throws -> String {
|
||||
let template = Template(templateString: self.template)
|
||||
return try template.render(self.data)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
import Foundation
|
||||
|
||||
import LlamaStackClient
|
||||
|
||||
func convertToNativeSwiftType(_ value: Any) -> Any {
|
||||
switch value {
|
||||
case let number as NSNumber:
|
||||
if CFGetTypeID(number) == CFBooleanGetTypeID() {
|
||||
return number.boolValue
|
||||
}
|
||||
if floor(number.doubleValue) == number.doubleValue {
|
||||
return number.intValue
|
||||
}
|
||||
return number.doubleValue
|
||||
case let string as String:
|
||||
return string
|
||||
case let array as [Any]:
|
||||
return array.map(convertToNativeSwiftType)
|
||||
case let dict as [String: Any]:
|
||||
return dict.mapValues(convertToNativeSwiftType)
|
||||
case is NSNull:
|
||||
return NSNull()
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
public class SystemDefaultGenerator {
|
||||
public init() {}
|
||||
|
||||
public func gen() -> PromptTemplate {
|
||||
let templateStr = """
|
||||
Cutting Knowledge Date: December 2023
|
||||
Today Date: {{ today }}
|
||||
"""
|
||||
|
||||
let dateFormatter = DateFormatter()
|
||||
dateFormatter.dateFormat = "dd MMMM yyyy"
|
||||
|
||||
return PromptTemplate(
|
||||
template: templateStr,
|
||||
data: ["today": dateFormatter.string(from: Date())]
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public class FunctionTagCustomToolGenerator {
|
||||
public init() {}
|
||||
|
||||
public func gen(customTools: [Components.Schemas.ToolDefinition]) throws -> PromptTemplate {
|
||||
// TODO: required params
|
||||
// TODO: {{#unless @last}},{{/unless}}
|
||||
|
||||
let templateStr = """
|
||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
||||
also point it out. You should only return the function call in tools call sections.
|
||||
|
||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
You SHOULD NOT include any other text in the response.
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
||||
[
|
||||
{% for t in custom_tools %}
|
||||
{
|
||||
"name": "{{t.tool_name}}",
|
||||
"description": "{{t.description}}",
|
||||
"input_schema": { {{t.input_schema}} }
|
||||
}
|
||||
|
||||
{{/let}}
|
||||
{% endfor -%}
|
||||
]
|
||||
"""
|
||||
|
||||
let encoder = JSONEncoder()
|
||||
return PromptTemplate(
|
||||
template: templateStr,
|
||||
data: ["custom_tools": try customTools.map {
|
||||
let data = try encoder.encode($0)
|
||||
let obj = try JSONSerialization.jsonObject(with: data)
|
||||
return convertToNativeSwiftType(obj)
|
||||
}]
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import gc
|
||||
|
||||
|
||||
def evacuate_model_from_device(model, device: str):
|
||||
"""Safely clear a model from memory and free device resources.
|
||||
This function handles the proper cleanup of a model by:
|
||||
1. Moving the model to CPU if it's on a non-CPU device
|
||||
2. Deleting the model object to free memory
|
||||
3. Running garbage collection
|
||||
4. Clearing CUDA cache if the model was on a CUDA device
|
||||
Args:
|
||||
model: The PyTorch model to clear
|
||||
device: The device type the model is currently on ('cuda', 'mps', 'cpu')
|
||||
Note:
|
||||
- For CUDA devices, this will clear the CUDA cache after moving the model to CPU
|
||||
- For MPS devices, only moves the model to CPU (no cache clearing available)
|
||||
- For CPU devices, only deletes the model object and runs garbage collection
|
||||
"""
|
||||
if device != "cpu":
|
||||
model.to("cpu")
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
|
||||
if device == "cuda":
|
||||
# we need to import such that this is only imported when the method is called
|
||||
import torch
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.type_system import (
|
||||
ChatCompletionInputType,
|
||||
DialogType,
|
||||
StringType,
|
||||
)
|
||||
from llama_stack.providers.utils.common.data_schema_validator import (
|
||||
ColumnName,
|
||||
)
|
||||
|
||||
EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
|
||||
"instruct": [
|
||||
{
|
||||
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
|
||||
ColumnName.expected_answer.value: StringType(),
|
||||
}
|
||||
],
|
||||
"dialog": [
|
||||
{
|
||||
ColumnName.dialog.value: DialogType(),
|
||||
}
|
||||
],
|
||||
}
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import Api
|
||||
|
||||
from .config import HuggingFacePostTrainingConfig
|
||||
|
||||
# post_training api and the huggingface provider is still experimental and under heavy development
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: HuggingFacePostTrainingConfig,
|
||||
deps: dict[Api, Any],
|
||||
):
|
||||
from .post_training import HuggingFacePostTrainingImpl
|
||||
|
||||
impl = HuggingFacePostTrainingImpl(
|
||||
config,
|
||||
deps[Api.datasetio],
|
||||
deps[Api.datasets],
|
||||
)
|
||||
return impl
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class HuggingFacePostTrainingConfig(BaseModel):
|
||||
# Device to run training on (cuda, cpu, mps)
|
||||
device: str = "cuda"
|
||||
|
||||
# Distributed training backend if using multiple devices
|
||||
# fsdp: Fully Sharded Data Parallel
|
||||
# deepspeed: DeepSpeed ZeRO optimization
|
||||
distributed_backend: Literal["fsdp", "deepspeed"] | None = None
|
||||
|
||||
# Format for saving model checkpoints
|
||||
# full_state: Save complete model state
|
||||
# huggingface: Save in HuggingFace format (recommended for compatibility)
|
||||
checkpoint_format: Literal["full_state", "huggingface"] | None = "huggingface"
|
||||
|
||||
# Template for formatting chat inputs and outputs
|
||||
# Used to structure the conversation format for training
|
||||
chat_template: str = "<|user|>\n{input}\n<|assistant|>\n{output}"
|
||||
|
||||
# Model-specific configuration parameters
|
||||
# trust_remote_code: Allow execution of custom model code
|
||||
# attn_implementation: Use SDPA (Scaled Dot Product Attention) for better performance
|
||||
model_specific_config: dict = {
|
||||
"trust_remote_code": True,
|
||||
"attn_implementation": "sdpa",
|
||||
}
|
||||
|
||||
# Maximum sequence length for training
|
||||
# Set to 2048 as this is the maximum that works reliably on MPS (Apple Silicon)
|
||||
# Longer sequences may cause memory issues on MPS devices
|
||||
max_seq_length: int = 2048
|
||||
|
||||
# Enable gradient checkpointing to reduce memory usage
|
||||
# Trades computation for memory by recomputing activations
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
# Maximum number of checkpoints to keep
|
||||
# Older checkpoints are deleted when this limit is reached
|
||||
save_total_limit: int = 3
|
||||
|
||||
# Number of training steps between logging updates
|
||||
logging_steps: int = 10
|
||||
|
||||
# Ratio of training steps used for learning rate warmup
|
||||
# Helps stabilize early training
|
||||
warmup_ratio: float = 0.1
|
||||
|
||||
# L2 regularization coefficient
|
||||
# Helps prevent overfitting
|
||||
weight_decay: float = 0.01
|
||||
|
||||
# Number of worker processes for data loading
|
||||
# Higher values can improve data loading speed but increase memory usage
|
||||
dataloader_num_workers: int = 4
|
||||
|
||||
# Whether to pin memory in data loader
|
||||
# Can improve data transfer speed to GPU but uses more memory
|
||||
dataloader_pin_memory: bool = True
|
||||
|
||||
# DPO-specific parameters
|
||||
dpo_beta: float = 0.1
|
||||
use_reference_model: bool = True
|
||||
dpo_loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
|
||||
dpo_output_dir: str
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"checkpoint_format": "huggingface",
|
||||
"distributed_backend": None,
|
||||
"device": "cpu",
|
||||
"dpo_output_dir": __distro_dir__ + "/dpo_output",
|
||||
}
|
||||
|
|
@ -0,0 +1,208 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
AlgorithmConfig,
|
||||
Checkpoint,
|
||||
DPOAlignmentConfig,
|
||||
JobStatus,
|
||||
ListPostTrainingJobsResponse,
|
||||
PostTrainingJob,
|
||||
PostTrainingJobArtifactsResponse,
|
||||
PostTrainingJobStatusResponse,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.huggingface.config import (
|
||||
HuggingFacePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
|
||||
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
|
||||
|
||||
|
||||
class TrainingArtifactType(Enum):
|
||||
CHECKPOINT = "checkpoint"
|
||||
RESOURCES_STATS = "resources_stats"
|
||||
|
||||
|
||||
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
|
||||
_JOB_TYPE_DPO_TRAINING = "dpo-training"
|
||||
|
||||
|
||||
class HuggingFacePostTrainingImpl:
|
||||
def __init__(
|
||||
self,
|
||||
config: HuggingFacePostTrainingConfig,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets: Datasets,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets
|
||||
self._scheduler = Scheduler()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await self._scheduler.shutdown()
|
||||
|
||||
@staticmethod
|
||||
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
|
||||
return JobArtifact(
|
||||
type=TrainingArtifactType.CHECKPOINT.value,
|
||||
name=checkpoint.identifier,
|
||||
uri=checkpoint.path,
|
||||
metadata=dict(checkpoint),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resources_stats_to_artifact(resources_stats: dict[str, Any]) -> JobArtifact:
|
||||
return JobArtifact(
|
||||
type=TrainingArtifactType.RESOURCES_STATS.value,
|
||||
name=TrainingArtifactType.RESOURCES_STATS.value,
|
||||
metadata=resources_stats,
|
||||
)
|
||||
|
||||
async def supervised_fine_tune(
|
||||
self,
|
||||
job_uuid: str,
|
||||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: dict[str, Any],
|
||||
logger_config: dict[str, Any],
|
||||
model: str,
|
||||
checkpoint_dir: str | None = None,
|
||||
algorithm_config: AlgorithmConfig | None = None,
|
||||
) -> PostTrainingJob:
|
||||
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
|
||||
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
|
||||
HFFinetuningSingleDevice,
|
||||
)
|
||||
|
||||
on_log_message_cb("Starting HF finetuning")
|
||||
|
||||
recipe = HFFinetuningSingleDevice(
|
||||
job_uuid=job_uuid,
|
||||
datasetio_api=self.datasetio_api,
|
||||
datasets_api=self.datasets_api,
|
||||
)
|
||||
|
||||
resources_allocated, checkpoints = await recipe.train(
|
||||
model=model,
|
||||
output_dir=checkpoint_dir,
|
||||
job_uuid=job_uuid,
|
||||
lora_config=algorithm_config,
|
||||
config=training_config,
|
||||
provider_config=self.config,
|
||||
)
|
||||
|
||||
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
|
||||
if checkpoints:
|
||||
for checkpoint in checkpoints:
|
||||
artifact = self._checkpoint_to_artifact(checkpoint)
|
||||
on_artifact_collected_cb(artifact)
|
||||
|
||||
on_status_change_cb(SchedulerJobStatus.completed)
|
||||
on_log_message_cb("HF finetuning completed")
|
||||
|
||||
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
|
||||
return PostTrainingJob(job_uuid=job_uuid)
|
||||
|
||||
async def preference_optimize(
|
||||
self,
|
||||
job_uuid: str,
|
||||
finetuned_model: str,
|
||||
algorithm_config: DPOAlignmentConfig,
|
||||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: dict[str, Any],
|
||||
logger_config: dict[str, Any],
|
||||
) -> PostTrainingJob:
|
||||
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
|
||||
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import (
|
||||
HFDPOAlignmentSingleDevice,
|
||||
)
|
||||
|
||||
on_log_message_cb("Starting HF DPO alignment")
|
||||
|
||||
recipe = HFDPOAlignmentSingleDevice(
|
||||
job_uuid=job_uuid,
|
||||
datasetio_api=self.datasetio_api,
|
||||
datasets_api=self.datasets_api,
|
||||
)
|
||||
|
||||
resources_allocated, checkpoints = await recipe.train(
|
||||
model=finetuned_model,
|
||||
output_dir=f"{self.config.dpo_output_dir}/{job_uuid}",
|
||||
job_uuid=job_uuid,
|
||||
dpo_config=algorithm_config,
|
||||
config=training_config,
|
||||
provider_config=self.config,
|
||||
)
|
||||
|
||||
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
|
||||
if checkpoints:
|
||||
for checkpoint in checkpoints:
|
||||
artifact = self._checkpoint_to_artifact(checkpoint)
|
||||
on_artifact_collected_cb(artifact)
|
||||
else:
|
||||
on_log_message_cb("Warning: No checkpoints were saved during DPO training")
|
||||
|
||||
on_status_change_cb(SchedulerJobStatus.completed)
|
||||
on_log_message_cb("HF DPO alignment completed")
|
||||
|
||||
job_uuid = self._scheduler.schedule(_JOB_TYPE_DPO_TRAINING, job_uuid, handler)
|
||||
return PostTrainingJob(job_uuid=job_uuid)
|
||||
|
||||
@staticmethod
|
||||
def _get_artifacts_metadata_by_type(job, artifact_type):
|
||||
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
|
||||
|
||||
@classmethod
|
||||
def _get_checkpoints(cls, job):
|
||||
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
|
||||
|
||||
@classmethod
|
||||
def _get_resources_allocated(cls, job):
|
||||
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
|
||||
return data[0] if data else None
|
||||
|
||||
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
|
||||
match job.status:
|
||||
# TODO: Add support for other statuses to API
|
||||
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
|
||||
status = JobStatus.scheduled
|
||||
case SchedulerJobStatus.running:
|
||||
status = JobStatus.in_progress
|
||||
case SchedulerJobStatus.completed:
|
||||
status = JobStatus.completed
|
||||
case SchedulerJobStatus.failed:
|
||||
status = JobStatus.failed
|
||||
case _:
|
||||
raise NotImplementedError()
|
||||
|
||||
return PostTrainingJobStatusResponse(
|
||||
job_uuid=job_uuid,
|
||||
status=status,
|
||||
scheduled_at=job.scheduled_at,
|
||||
started_at=job.started_at,
|
||||
completed_at=job.completed_at,
|
||||
checkpoints=self._get_checkpoints(job),
|
||||
resources_allocated=self._get_resources_allocated(job),
|
||||
)
|
||||
|
||||
async def cancel_training_job(self, job_uuid: str) -> None:
|
||||
self._scheduler.cancel(job_uuid)
|
||||
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
|
||||
|
||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
|
||||
return ListPostTrainingJobsResponse(
|
||||
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
|
||||
)
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -0,0 +1,519 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import gc
|
||||
import json
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from peft import LoraConfig
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
)
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
Checkpoint,
|
||||
DataConfig,
|
||||
LoraFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
|
||||
from ..config import HuggingFacePostTrainingConfig
|
||||
from ..utils import (
|
||||
calculate_training_steps,
|
||||
create_checkpoints,
|
||||
get_memory_stats,
|
||||
get_save_strategy,
|
||||
load_model,
|
||||
load_rows_from_dataset,
|
||||
setup_environment,
|
||||
setup_signal_handlers,
|
||||
setup_torch_device,
|
||||
split_dataset,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="post_training")
|
||||
|
||||
|
||||
class HFFinetuningSingleDevice:
|
||||
def __init__(
|
||||
self,
|
||||
job_uuid: str,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets_api: Datasets,
|
||||
):
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets_api
|
||||
self.job_uuid = job_uuid
|
||||
|
||||
def validate_dataset_format(self, rows: list[dict]) -> bool:
|
||||
"""Validate that the dataset has the required fields."""
|
||||
required_fields = ["input_query", "expected_answer", "chat_completion_input"]
|
||||
return all(field in row for row in rows for field in required_fields)
|
||||
|
||||
def _process_instruct_format(self, row: dict) -> tuple[str | None, str | None]:
|
||||
"""Process a row in instruct format."""
|
||||
if "chat_completion_input" in row and "expected_answer" in row:
|
||||
try:
|
||||
messages = json.loads(row["chat_completion_input"])
|
||||
if not isinstance(messages, list) or len(messages) != 1:
|
||||
logger.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}")
|
||||
return None, None
|
||||
if "content" not in messages[0]:
|
||||
logger.warning(f"Message missing content: {messages[0]}")
|
||||
return None, None
|
||||
return messages[0]["content"], row["expected_answer"]
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}")
|
||||
return None, None
|
||||
return None, None
|
||||
|
||||
def _process_dialog_format(self, row: dict) -> tuple[str | None, str | None]:
|
||||
"""Process a row in dialog format."""
|
||||
if "dialog" in row:
|
||||
try:
|
||||
dialog = json.loads(row["dialog"])
|
||||
if not isinstance(dialog, list) or len(dialog) < 2:
|
||||
logger.warning(f"Dialog must have at least 2 messages: {row['dialog']}")
|
||||
return None, None
|
||||
if dialog[0].get("role") != "user":
|
||||
logger.warning(f"First message must be from user: {dialog[0]}")
|
||||
return None, None
|
||||
if not any(msg.get("role") == "assistant" for msg in dialog):
|
||||
logger.warning("Dialog must have at least one assistant message")
|
||||
return None, None
|
||||
|
||||
# Convert to human/gpt format
|
||||
role_map = {"user": "human", "assistant": "gpt"}
|
||||
conversations = []
|
||||
for msg in dialog:
|
||||
if "role" not in msg or "content" not in msg:
|
||||
logger.warning(f"Message missing role or content: {msg}")
|
||||
continue
|
||||
conversations.append({"from": role_map[msg["role"]], "value": msg["content"]})
|
||||
|
||||
# Format as a single conversation
|
||||
return conversations[0]["value"], conversations[1]["value"]
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse dialog: {row['dialog']}")
|
||||
return None, None
|
||||
return None, None
|
||||
|
||||
def _process_fallback_format(self, row: dict) -> tuple[str | None, str | None]:
|
||||
"""Process a row using fallback formats."""
|
||||
if "input" in row and "output" in row:
|
||||
return row["input"], row["output"]
|
||||
elif "prompt" in row and "completion" in row:
|
||||
return row["prompt"], row["completion"]
|
||||
elif "question" in row and "answer" in row:
|
||||
return row["question"], row["answer"]
|
||||
return None, None
|
||||
|
||||
def _format_text(self, input_text: str, output_text: str, provider_config: HuggingFacePostTrainingConfig) -> str:
|
||||
"""Format input and output text based on model requirements."""
|
||||
if hasattr(provider_config, "chat_template"):
|
||||
return provider_config.chat_template.format(input=input_text, output=output_text)
|
||||
return f"{input_text}\n{output_text}"
|
||||
|
||||
def _create_dataset(
|
||||
self, rows: list[dict], config: TrainingConfig, provider_config: HuggingFacePostTrainingConfig
|
||||
) -> Dataset:
|
||||
"""Create and preprocess the dataset."""
|
||||
formatted_rows = []
|
||||
for row in rows:
|
||||
input_text = None
|
||||
output_text = None
|
||||
|
||||
# Process based on format
|
||||
assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized"
|
||||
if config.data_config.data_format.value == "instruct":
|
||||
input_text, output_text = self._process_instruct_format(row)
|
||||
elif config.data_config.data_format.value == "dialog":
|
||||
input_text, output_text = self._process_dialog_format(row)
|
||||
else:
|
||||
input_text, output_text = self._process_fallback_format(row)
|
||||
|
||||
if input_text and output_text:
|
||||
formatted_text = self._format_text(input_text, output_text, provider_config)
|
||||
formatted_rows.append({"text": formatted_text})
|
||||
|
||||
if not formatted_rows:
|
||||
assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized"
|
||||
raise ValueError(
|
||||
f"No valid input/output pairs found in the dataset for format: {config.data_config.data_format.value}"
|
||||
)
|
||||
|
||||
return Dataset.from_list(formatted_rows)
|
||||
|
||||
def _preprocess_dataset(
|
||||
self, ds: Dataset, tokenizer: AutoTokenizer, provider_config: HuggingFacePostTrainingConfig
|
||||
) -> Dataset:
|
||||
"""Preprocess the dataset with tokenizer."""
|
||||
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(
|
||||
examples["text"],
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=provider_config.max_seq_length,
|
||||
return_tensors=None,
|
||||
)
|
||||
|
||||
return ds.map(
|
||||
tokenize_function,
|
||||
batched=True,
|
||||
remove_columns=ds.column_names,
|
||||
)
|
||||
|
||||
def _run_training_sync(
|
||||
self,
|
||||
model: str,
|
||||
provider_config: dict[str, Any],
|
||||
peft_config: LoraConfig | None,
|
||||
config: dict[str, Any],
|
||||
output_dir_path: Path | None,
|
||||
) -> None:
|
||||
"""Synchronous wrapper for running training process.
|
||||
This method serves as a bridge between the multiprocessing Process and the async training function.
|
||||
It creates a new event loop to run the async training process.
|
||||
Args:
|
||||
model: The model identifier to load
|
||||
dataset_id: ID of the dataset to use for training
|
||||
provider_config: Configuration specific to the HuggingFace provider
|
||||
peft_config: Optional LoRA configuration
|
||||
config: General training configuration
|
||||
output_dir_path: Optional path to save the model
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
logger.info("Starting training process with async wrapper")
|
||||
asyncio.run(
|
||||
self._run_training(
|
||||
model=model,
|
||||
provider_config=provider_config,
|
||||
peft_config=peft_config,
|
||||
config=config,
|
||||
output_dir_path=output_dir_path,
|
||||
)
|
||||
)
|
||||
|
||||
async def load_dataset(
|
||||
self,
|
||||
model: str,
|
||||
config: TrainingConfig,
|
||||
provider_config: HuggingFacePostTrainingConfig,
|
||||
) -> tuple[Dataset, Dataset, AutoTokenizer]:
|
||||
"""Load and prepare the dataset for training.
|
||||
Args:
|
||||
model: The model identifier to load
|
||||
config: Training configuration
|
||||
provider_config: Provider-specific configuration
|
||||
Returns:
|
||||
tuple: (train_dataset, eval_dataset, tokenizer)
|
||||
"""
|
||||
# Validate data config
|
||||
if not config.data_config:
|
||||
raise ValueError("DataConfig is required for training")
|
||||
|
||||
# Load dataset
|
||||
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
|
||||
rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
|
||||
if not self.validate_dataset_format(rows):
|
||||
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
|
||||
logger.info(f"Loaded {len(rows)} rows from dataset")
|
||||
|
||||
# Initialize tokenizer
|
||||
logger.info(f"Initializing tokenizer for model: {model}")
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config)
|
||||
|
||||
# Set pad token to eos token if not present
|
||||
# This is common for models that don't have a dedicated pad token
|
||||
if not tokenizer.pad_token:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# Set padding side to right for causal language modeling
|
||||
# This ensures that padding tokens don't interfere with the model's ability
|
||||
# to predict the next token in the sequence
|
||||
tokenizer.padding_side = "right"
|
||||
|
||||
# Set truncation side to right to keep the beginning of the sequence
|
||||
# This is important for maintaining context and instruction format
|
||||
tokenizer.truncation_side = "right"
|
||||
|
||||
# Set model max length to match provider config
|
||||
# This ensures consistent sequence lengths across the training process
|
||||
tokenizer.model_max_length = provider_config.max_seq_length
|
||||
|
||||
logger.info("Tokenizer initialized successfully")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e
|
||||
|
||||
# Create and preprocess dataset
|
||||
logger.info("Creating and preprocessing dataset")
|
||||
try:
|
||||
ds = self._create_dataset(rows, config, provider_config)
|
||||
ds = self._preprocess_dataset(ds, tokenizer, provider_config)
|
||||
logger.info(f"Dataset created with {len(ds)} examples")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to create dataset: {str(e)}") from e
|
||||
|
||||
# Split dataset
|
||||
train_dataset, eval_dataset = split_dataset(ds)
|
||||
|
||||
return train_dataset, eval_dataset, tokenizer
|
||||
|
||||
def setup_training_args(
|
||||
self,
|
||||
config: TrainingConfig,
|
||||
provider_config: HuggingFacePostTrainingConfig,
|
||||
device: torch.device,
|
||||
output_dir_path: Path | None,
|
||||
steps_per_epoch: int,
|
||||
) -> SFTConfig:
|
||||
"""Setup training arguments.
|
||||
Args:
|
||||
config: Training configuration
|
||||
provider_config: Provider-specific configuration
|
||||
device: The device to train on
|
||||
output_dir_path: Optional path to save the model
|
||||
steps_per_epoch: Number of steps per epoch
|
||||
Returns:
|
||||
Configured SFTConfig object
|
||||
"""
|
||||
logger.info("Configuring training arguments")
|
||||
lr = 2e-5
|
||||
if config.optimizer_config:
|
||||
lr = config.optimizer_config.lr
|
||||
logger.info(f"Using custom learning rate: {lr}")
|
||||
|
||||
# Validate data config
|
||||
if not config.data_config:
|
||||
raise ValueError("DataConfig is required for training")
|
||||
data_config = config.data_config
|
||||
|
||||
# Calculate steps and get save strategy
|
||||
step_info = calculate_training_steps(steps_per_epoch, config)
|
||||
save_strategy, eval_strategy = get_save_strategy(output_dir_path)
|
||||
|
||||
return SFTConfig(
|
||||
max_steps=step_info["max_steps"],
|
||||
output_dir=str(output_dir_path) if output_dir_path is not None else None,
|
||||
num_train_epochs=config.n_epochs,
|
||||
per_device_train_batch_size=data_config.batch_size,
|
||||
fp16=device.type == "cuda",
|
||||
bf16=False, # Causes CPU issues.
|
||||
eval_strategy=eval_strategy,
|
||||
use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False,
|
||||
save_strategy=save_strategy,
|
||||
report_to="none",
|
||||
max_length=provider_config.max_seq_length,
|
||||
gradient_accumulation_steps=config.gradient_accumulation_steps,
|
||||
gradient_checkpointing=provider_config.gradient_checkpointing,
|
||||
learning_rate=lr,
|
||||
warmup_ratio=provider_config.warmup_ratio,
|
||||
weight_decay=provider_config.weight_decay,
|
||||
remove_unused_columns=False,
|
||||
dataloader_pin_memory=provider_config.dataloader_pin_memory,
|
||||
dataloader_num_workers=provider_config.dataloader_num_workers,
|
||||
dataset_text_field="text",
|
||||
packing=False,
|
||||
load_best_model_at_end=True if output_dir_path else False,
|
||||
metric_for_best_model="eval_loss",
|
||||
greater_is_better=False,
|
||||
logging_steps=step_info["logging_steps"],
|
||||
)
|
||||
|
||||
def save_model(
|
||||
self,
|
||||
model_obj: AutoModelForCausalLM,
|
||||
trainer: SFTTrainer,
|
||||
peft_config: LoraConfig | None,
|
||||
output_dir_path: Path,
|
||||
) -> None:
|
||||
"""Save the trained model.
|
||||
Args:
|
||||
model_obj: The model to save
|
||||
trainer: The trainer instance
|
||||
peft_config: Optional LoRA configuration
|
||||
output_dir_path: Path to save the model
|
||||
"""
|
||||
logger.info("Saving final model")
|
||||
model_obj.config.use_cache = True
|
||||
|
||||
if peft_config:
|
||||
logger.info("Merging LoRA weights with base model")
|
||||
model_obj = trainer.model.merge_and_unload()
|
||||
else:
|
||||
model_obj = trainer.model
|
||||
|
||||
save_path = output_dir_path / "merged_model"
|
||||
logger.info(f"Saving model to {save_path}")
|
||||
model_obj.save_pretrained(save_path)
|
||||
|
||||
async def _run_training(
|
||||
self,
|
||||
model: str,
|
||||
provider_config: dict[str, Any],
|
||||
peft_config: LoraConfig | None,
|
||||
config: dict[str, Any],
|
||||
output_dir_path: Path | None,
|
||||
) -> None:
|
||||
"""Run the training process with signal handling."""
|
||||
|
||||
# Setup environment variables
|
||||
setup_environment()
|
||||
|
||||
# Setup signal handlers
|
||||
setup_signal_handlers()
|
||||
|
||||
# Convert config dicts back to objects
|
||||
logger.info("Initializing configuration objects")
|
||||
provider_config_obj = HuggingFacePostTrainingConfig(**provider_config)
|
||||
config_obj = TrainingConfig(**config)
|
||||
|
||||
# Initialize and validate device
|
||||
device = setup_torch_device(provider_config_obj.device)
|
||||
logger.info(f"Using device '{device}'")
|
||||
|
||||
# Load dataset and tokenizer
|
||||
train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)
|
||||
|
||||
# Calculate steps per epoch
|
||||
if not config_obj.data_config:
|
||||
raise ValueError("DataConfig is required for training")
|
||||
steps_per_epoch = len(train_dataset) // config_obj.data_config.batch_size
|
||||
|
||||
# Setup training arguments
|
||||
training_args = self.setup_training_args(
|
||||
config_obj,
|
||||
provider_config_obj,
|
||||
device,
|
||||
output_dir_path,
|
||||
steps_per_epoch,
|
||||
)
|
||||
|
||||
# Load model
|
||||
model_obj = load_model(model, device, provider_config_obj)
|
||||
|
||||
# Initialize trainer
|
||||
logger.info("Initializing SFTTrainer")
|
||||
trainer = SFTTrainer(
|
||||
model=model_obj,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
peft_config=peft_config,
|
||||
args=training_args,
|
||||
)
|
||||
|
||||
try:
|
||||
# Train
|
||||
logger.info("Starting training")
|
||||
trainer.train()
|
||||
logger.info("Training completed successfully")
|
||||
|
||||
# Save final model if output directory is provided
|
||||
if output_dir_path:
|
||||
self.save_model(model_obj, trainer, peft_config, output_dir_path)
|
||||
|
||||
finally:
|
||||
# Clean up resources
|
||||
logger.info("Cleaning up resources")
|
||||
if hasattr(trainer, "model"):
|
||||
evacuate_model_from_device(trainer.model, device.type)
|
||||
del trainer
|
||||
gc.collect()
|
||||
logger.info("Cleanup completed")
|
||||
|
||||
async def train(
|
||||
self,
|
||||
model: str,
|
||||
output_dir: str | None,
|
||||
job_uuid: str,
|
||||
lora_config: LoraFinetuningConfig,
|
||||
config: TrainingConfig,
|
||||
provider_config: HuggingFacePostTrainingConfig,
|
||||
) -> tuple[dict[str, Any], list[Checkpoint] | None]:
|
||||
"""Train a model using HuggingFace's SFTTrainer"""
|
||||
# Initialize and validate device
|
||||
device = setup_torch_device(provider_config.device)
|
||||
logger.info(f"Using device '{device}'")
|
||||
|
||||
output_dir_path = None
|
||||
if output_dir:
|
||||
output_dir_path = Path(output_dir)
|
||||
|
||||
# Track memory stats
|
||||
memory_stats = {
|
||||
"initial": get_memory_stats(device),
|
||||
"after_training": None,
|
||||
"final": None,
|
||||
}
|
||||
|
||||
# Configure LoRA
|
||||
peft_config = None
|
||||
if lora_config:
|
||||
peft_config = LoraConfig(
|
||||
lora_alpha=lora_config.alpha,
|
||||
lora_dropout=0.1,
|
||||
r=lora_config.rank,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
target_modules=lora_config.lora_attn_modules,
|
||||
)
|
||||
|
||||
# Validate data config
|
||||
if not config.data_config:
|
||||
raise ValueError("DataConfig is required for training")
|
||||
|
||||
# Train in a separate process
|
||||
logger.info("Starting training in separate process")
|
||||
try:
|
||||
# Setup multiprocessing for device
|
||||
if device.type in ["cuda", "mps"]:
|
||||
multiprocessing.set_start_method("spawn", force=True)
|
||||
|
||||
process = multiprocessing.Process(
|
||||
target=self._run_training_sync,
|
||||
kwargs={
|
||||
"model": model,
|
||||
"provider_config": provider_config.model_dump(),
|
||||
"peft_config": peft_config,
|
||||
"config": config.model_dump(),
|
||||
"output_dir_path": output_dir_path,
|
||||
},
|
||||
)
|
||||
process.start()
|
||||
|
||||
# Monitor the process
|
||||
while process.is_alive():
|
||||
process.join(timeout=1) # Check every second
|
||||
if not process.is_alive():
|
||||
break
|
||||
|
||||
# Get the return code
|
||||
if process.exitcode != 0:
|
||||
raise RuntimeError(f"Training failed with exit code {process.exitcode}")
|
||||
|
||||
memory_stats["after_training"] = get_memory_stats(device)
|
||||
|
||||
checkpoints = []
|
||||
if output_dir_path:
|
||||
checkpoints = create_checkpoints(output_dir_path, job_uuid, model, config, "merged_model")
|
||||
|
||||
return memory_stats, checkpoints if checkpoints else None
|
||||
finally:
|
||||
memory_stats["final"] = get_memory_stats(device)
|
||||
gc.collect()
|
||||
|
|
@ -0,0 +1,485 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import gc
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
)
|
||||
from trl import DPOConfig, DPOTrainer
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
Checkpoint,
|
||||
DPOAlignmentConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
|
||||
from ..config import HuggingFacePostTrainingConfig
|
||||
from ..utils import (
|
||||
calculate_training_steps,
|
||||
create_checkpoints,
|
||||
get_memory_stats,
|
||||
get_save_strategy,
|
||||
load_model,
|
||||
load_rows_from_dataset,
|
||||
setup_environment,
|
||||
setup_signal_handlers,
|
||||
setup_torch_device,
|
||||
split_dataset,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="post_training")
|
||||
|
||||
|
||||
class HFDPOAlignmentSingleDevice:
|
||||
def __init__(
|
||||
self,
|
||||
job_uuid: str,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets_api: Datasets,
|
||||
):
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets_api
|
||||
self.job_uuid = job_uuid
|
||||
|
||||
def validate_dataset_format(self, rows: list[dict]) -> None:
|
||||
"""Validate that the dataset has the required fields for DPO training."""
|
||||
required_fields = ["prompt", "chosen", "rejected"]
|
||||
|
||||
if not rows:
|
||||
logger.warning("Dataset is empty")
|
||||
raise ValueError("Dataset is empty")
|
||||
|
||||
for i, row in enumerate(rows):
|
||||
if not isinstance(row, dict):
|
||||
logger.warning(f"Row {i} is not a dictionary")
|
||||
raise ValueError(f"Row {i} is not a dictionary")
|
||||
|
||||
for field in required_fields:
|
||||
if field not in row:
|
||||
logger.warning(f"Row {i} missing required DPO field: {field}")
|
||||
raise ValueError(f"Row {i} missing required DPO field: {field}")
|
||||
|
||||
# Handle both string and list formats
|
||||
if field == "prompt":
|
||||
# Prompt should be a string
|
||||
if not isinstance(row[field], str):
|
||||
logger.warning(f"Row {i} field '{field}' is not a string")
|
||||
raise ValueError(f"Row {i} field '{field}' is not a string")
|
||||
if not row[field].strip():
|
||||
logger.warning(f"Row {i} field '{field}' is empty")
|
||||
raise ValueError(f"Row {i} field '{field}' is empty")
|
||||
else:
|
||||
# chosen/rejected can be either strings or lists of messages
|
||||
if isinstance(row[field], str):
|
||||
if not row[field].strip():
|
||||
logger.warning(f"Row {i} field '{field}' is empty")
|
||||
raise ValueError(f"Row {i} field '{field}' is empty")
|
||||
elif isinstance(row[field], list):
|
||||
if not row[field]:
|
||||
logger.warning(f"Row {i} field '{field}' is empty list")
|
||||
raise ValueError(f"Row {i} field '{field}' is empty list")
|
||||
else:
|
||||
logger.warning(f"Row {i} field '{field}' is neither string nor list")
|
||||
raise ValueError(f"Row {i} field '{field}' is neither string nor list")
|
||||
|
||||
logger.info(f"DPO dataset validation passed: {len(rows)} preference examples")
|
||||
|
||||
def _process_dpo_format(self, row: dict) -> tuple[str | None, str | None, str | None]:
|
||||
"""Process a row in DPO format, handling both string and conversation list formats."""
|
||||
if all(field in row for field in ["prompt", "chosen", "rejected"]):
|
||||
prompt = row["prompt"]
|
||||
|
||||
# Handle chosen field - convert list to string if needed
|
||||
if isinstance(row["chosen"], list):
|
||||
# For conversation format, concatenate messages
|
||||
chosen = "\n".join(
|
||||
[msg.get("content", "") if isinstance(msg, dict) else str(msg) for msg in row["chosen"]]
|
||||
)
|
||||
else:
|
||||
chosen = row["chosen"]
|
||||
|
||||
# Handle rejected field - convert list to string if needed
|
||||
if isinstance(row["rejected"], list):
|
||||
# For conversation format, concatenate messages
|
||||
rejected = "\n".join(
|
||||
[msg.get("content", "") if isinstance(msg, dict) else str(msg) for msg in row["rejected"]]
|
||||
)
|
||||
else:
|
||||
rejected = row["rejected"]
|
||||
|
||||
return prompt, chosen, rejected
|
||||
return None, None, None
|
||||
|
||||
def _format_text_for_dpo(self, prompt: str, response: str, provider_config: HuggingFacePostTrainingConfig) -> str:
|
||||
"""Format prompt and response text based on model requirements."""
|
||||
if hasattr(provider_config, "chat_template") and provider_config.chat_template:
|
||||
# Use the chat template, supporting both {prompt}/{response} and {input}/{output}
|
||||
template = provider_config.chat_template
|
||||
# Try prompt/response first (DPO style)
|
||||
if "{prompt}" in template and "{response}" in template:
|
||||
return template.format(prompt=prompt, response=response)
|
||||
# Fall back to input/output (SFT style)
|
||||
elif "{input}" in template and "{output}" in template:
|
||||
return template.format(input=prompt, output=response)
|
||||
else:
|
||||
# If template doesn't have expected placeholders, use default
|
||||
return f"{prompt}\n{response}"
|
||||
return f"{prompt}\n{response}"
|
||||
|
||||
def _create_dataset(
|
||||
self, rows: list[dict], config: TrainingConfig, provider_config: HuggingFacePostTrainingConfig
|
||||
) -> Dataset:
|
||||
"""Create and preprocess the dataset for DPO."""
|
||||
dpo_examples = []
|
||||
for row in rows:
|
||||
prompt, chosen, rejected = self._process_dpo_format(row)
|
||||
|
||||
if prompt and chosen and rejected:
|
||||
# Format the texts
|
||||
chosen_formatted = self._format_text_for_dpo(prompt, chosen, provider_config)
|
||||
rejected_formatted = self._format_text_for_dpo(prompt, rejected, provider_config)
|
||||
|
||||
dpo_examples.append(
|
||||
{
|
||||
"prompt": prompt,
|
||||
"chosen": chosen_formatted,
|
||||
"rejected": rejected_formatted,
|
||||
}
|
||||
)
|
||||
|
||||
if not dpo_examples:
|
||||
raise ValueError("No valid preference examples found in dataset")
|
||||
|
||||
logger.info(f"Created DPO dataset with {len(dpo_examples)} preference pairs")
|
||||
return Dataset.from_list(dpo_examples)
|
||||
|
||||
def _preprocess_dataset(
|
||||
self, ds: Dataset, tokenizer: AutoTokenizer, provider_config: HuggingFacePostTrainingConfig
|
||||
) -> Dataset:
|
||||
"""Preprocess the dataset with tokenizer for DPO."""
|
||||
# DPOTrainer expects raw text, so we don't tokenize here
|
||||
# Just return the dataset as is
|
||||
return ds
|
||||
|
||||
def _run_training_sync(
|
||||
self,
|
||||
model: str,
|
||||
provider_config: dict[str, Any],
|
||||
dpo_config: dict[str, Any],
|
||||
config: dict[str, Any],
|
||||
output_dir_path: Path | None,
|
||||
) -> None:
|
||||
"""Synchronous wrapper for running DPO training process."""
|
||||
import asyncio
|
||||
|
||||
logger.info("Starting DPO training process with async wrapper")
|
||||
asyncio.run(
|
||||
self._run_training(
|
||||
model=model,
|
||||
provider_config=provider_config,
|
||||
dpo_config=dpo_config,
|
||||
config=config,
|
||||
output_dir_path=output_dir_path,
|
||||
)
|
||||
)
|
||||
|
||||
async def load_dataset(
|
||||
self,
|
||||
model: str,
|
||||
config: TrainingConfig,
|
||||
provider_config: HuggingFacePostTrainingConfig,
|
||||
) -> tuple[Dataset, Dataset, AutoTokenizer]:
|
||||
"""Load and prepare the dataset for DPO training."""
|
||||
# Validate data config
|
||||
if not config.data_config:
|
||||
raise ValueError("DataConfig is required for DPO training")
|
||||
|
||||
# Load dataset
|
||||
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
|
||||
rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
|
||||
self.validate_dataset_format(rows)
|
||||
logger.info(f"Loaded {len(rows)} rows from dataset")
|
||||
|
||||
# Initialize tokenizer
|
||||
logger.info(f"Initializing tokenizer for model: {model}")
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config)
|
||||
|
||||
# Set pad token to eos token if not present
|
||||
if not tokenizer.pad_token:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# Set padding side to left for DPO
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
# Set truncation side to right to keep the beginning of the sequence
|
||||
tokenizer.truncation_side = "right"
|
||||
|
||||
# Set model max length to match provider config
|
||||
tokenizer.model_max_length = provider_config.max_seq_length
|
||||
|
||||
logger.info("Tokenizer initialized successfully for DPO")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e
|
||||
|
||||
# Create and preprocess dataset
|
||||
logger.info("Creating and preprocessing dataset for DPO")
|
||||
try:
|
||||
ds = self._create_dataset(rows, config, provider_config)
|
||||
ds = self._preprocess_dataset(ds, tokenizer, provider_config)
|
||||
logger.info(f"Dataset created with {len(ds)} examples")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to create dataset: {str(e)}") from e
|
||||
|
||||
# Split dataset
|
||||
train_dataset, eval_dataset = split_dataset(ds)
|
||||
|
||||
return train_dataset, eval_dataset, tokenizer
|
||||
|
||||
def setup_training_args(
|
||||
self,
|
||||
config: TrainingConfig,
|
||||
provider_config: HuggingFacePostTrainingConfig,
|
||||
dpo_config: DPOAlignmentConfig,
|
||||
device: torch.device,
|
||||
output_dir_path: Path | None,
|
||||
steps_per_epoch: int,
|
||||
) -> DPOConfig:
|
||||
"""Setup DPO training arguments."""
|
||||
logger.info("Configuring DPO training arguments")
|
||||
lr = 5e-7 # Lower learning rate for DPO
|
||||
if config.optimizer_config:
|
||||
lr = config.optimizer_config.lr
|
||||
logger.info(f"Using custom learning rate: {lr}")
|
||||
|
||||
# Validate data config
|
||||
if not config.data_config:
|
||||
raise ValueError("DataConfig is required for training")
|
||||
data_config = config.data_config
|
||||
|
||||
# Calculate steps and get save strategy
|
||||
step_info = calculate_training_steps(steps_per_epoch, config)
|
||||
save_strategy, eval_strategy = get_save_strategy(output_dir_path)
|
||||
|
||||
logger.info("DPO training configuration:")
|
||||
logger.info(f"- DPO beta: {dpo_config.beta}")
|
||||
logger.info(f"- DPO loss type: {provider_config.dpo_loss_type}")
|
||||
|
||||
# Calculate max prompt length as half of max sequence length
|
||||
max_prompt_length = provider_config.max_seq_length // 2
|
||||
|
||||
return DPOConfig(
|
||||
max_steps=step_info["max_steps"],
|
||||
output_dir=str(output_dir_path) if output_dir_path is not None else None,
|
||||
num_train_epochs=config.n_epochs,
|
||||
per_device_train_batch_size=data_config.batch_size,
|
||||
fp16=device.type == "cuda",
|
||||
bf16=False, # Causes CPU issues.
|
||||
eval_strategy=eval_strategy,
|
||||
use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False,
|
||||
save_strategy=save_strategy,
|
||||
report_to="none",
|
||||
max_length=provider_config.max_seq_length,
|
||||
max_prompt_length=max_prompt_length,
|
||||
gradient_accumulation_steps=config.gradient_accumulation_steps,
|
||||
gradient_checkpointing=provider_config.gradient_checkpointing,
|
||||
learning_rate=lr,
|
||||
warmup_ratio=provider_config.warmup_ratio,
|
||||
weight_decay=provider_config.weight_decay,
|
||||
remove_unused_columns=False,
|
||||
dataloader_pin_memory=provider_config.dataloader_pin_memory,
|
||||
dataloader_num_workers=provider_config.dataloader_num_workers,
|
||||
load_best_model_at_end=True if output_dir_path else False,
|
||||
metric_for_best_model="eval_loss",
|
||||
greater_is_better=False,
|
||||
logging_steps=step_info["logging_steps"],
|
||||
save_total_limit=provider_config.save_total_limit,
|
||||
# DPO specific parameters
|
||||
beta=dpo_config.beta,
|
||||
loss_type=provider_config.dpo_loss_type,
|
||||
)
|
||||
|
||||
def save_model(
|
||||
self,
|
||||
trainer: DPOTrainer,
|
||||
output_dir_path: Path,
|
||||
) -> None:
|
||||
"""Save the trained DPO model."""
|
||||
logger.info("Saving final DPO model")
|
||||
|
||||
save_path = output_dir_path / "dpo_model"
|
||||
logger.info(f"Saving model to {save_path}")
|
||||
|
||||
# Save model and tokenizer
|
||||
trainer.save_model(str(save_path))
|
||||
|
||||
async def _run_training(
|
||||
self,
|
||||
model: str,
|
||||
provider_config: dict[str, Any],
|
||||
dpo_config: dict[str, Any],
|
||||
config: dict[str, Any],
|
||||
output_dir_path: Path | None,
|
||||
) -> None:
|
||||
"""Run the DPO training process with signal handling."""
|
||||
|
||||
# Setup environment variables
|
||||
setup_environment()
|
||||
|
||||
# Setup signal handlers
|
||||
setup_signal_handlers()
|
||||
|
||||
# Convert config dicts back to objects
|
||||
logger.info("Initializing configuration objects")
|
||||
provider_config_obj = HuggingFacePostTrainingConfig(**provider_config)
|
||||
config_obj = TrainingConfig(**config)
|
||||
dpo_config_obj = DPOAlignmentConfig(**dpo_config)
|
||||
|
||||
# Initialize and validate device
|
||||
device = setup_torch_device(provider_config_obj.device)
|
||||
logger.info(f"Using device '{device}'")
|
||||
|
||||
# Load dataset and tokenizer
|
||||
train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)
|
||||
|
||||
# Calculate steps per epoch
|
||||
if not config_obj.data_config:
|
||||
raise ValueError("DataConfig is required for training")
|
||||
steps_per_epoch = len(train_dataset) // config_obj.data_config.batch_size
|
||||
|
||||
# Setup training arguments
|
||||
training_args = self.setup_training_args(
|
||||
config_obj,
|
||||
provider_config_obj,
|
||||
dpo_config_obj,
|
||||
device,
|
||||
output_dir_path,
|
||||
steps_per_epoch,
|
||||
)
|
||||
|
||||
# Load model and reference model
|
||||
model_obj = load_model(model, device, provider_config_obj)
|
||||
ref_model = None
|
||||
if provider_config_obj.use_reference_model:
|
||||
logger.info("Loading separate reference model for DPO")
|
||||
ref_model = load_model(model, device, provider_config_obj)
|
||||
else:
|
||||
logger.info("Using shared reference model for DPO")
|
||||
|
||||
# Initialize DPO trainer
|
||||
logger.info("Initializing DPOTrainer")
|
||||
trainer = DPOTrainer(
|
||||
model=model_obj,
|
||||
ref_model=ref_model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=tokenizer,
|
||||
)
|
||||
|
||||
try:
|
||||
# Train
|
||||
logger.info("Starting DPO training")
|
||||
trainer.train()
|
||||
logger.info("DPO training completed successfully")
|
||||
|
||||
# Save final model if output directory is provided
|
||||
if output_dir_path:
|
||||
logger.info(f"Saving model to output directory: {output_dir_path}")
|
||||
self.save_model(trainer, output_dir_path)
|
||||
logger.info("Model save completed")
|
||||
|
||||
finally:
|
||||
# Clean up resources
|
||||
logger.info("Cleaning up resources")
|
||||
if hasattr(trainer, "model"):
|
||||
evacuate_model_from_device(trainer.model, device.type)
|
||||
if ref_model:
|
||||
evacuate_model_from_device(ref_model, device.type)
|
||||
del trainer
|
||||
del ref_model
|
||||
gc.collect()
|
||||
logger.info("Cleanup completed")
|
||||
logger.info("DPO training process finishing successfully")
|
||||
|
||||
async def train(
|
||||
self,
|
||||
model: str,
|
||||
output_dir: str | None,
|
||||
job_uuid: str,
|
||||
dpo_config: DPOAlignmentConfig,
|
||||
config: TrainingConfig,
|
||||
provider_config: HuggingFacePostTrainingConfig,
|
||||
) -> tuple[dict[str, Any], list[Checkpoint] | None]:
|
||||
"""Train a model using HuggingFace's DPOTrainer"""
|
||||
# Initialize and validate device
|
||||
device = setup_torch_device(provider_config.device)
|
||||
logger.info(f"Using device '{device}'")
|
||||
|
||||
output_dir_path = None
|
||||
if output_dir:
|
||||
output_dir_path = Path(output_dir)
|
||||
|
||||
# Track memory stats
|
||||
memory_stats = {
|
||||
"initial": get_memory_stats(device),
|
||||
"after_training": None,
|
||||
"final": None,
|
||||
}
|
||||
|
||||
# Validate data config
|
||||
if not config.data_config:
|
||||
raise ValueError("DataConfig is required for training")
|
||||
|
||||
# Train in a separate process
|
||||
logger.info("Starting DPO training in separate process")
|
||||
try:
|
||||
# Setup multiprocessing for device
|
||||
if device.type in ["cuda", "mps"]:
|
||||
multiprocessing.set_start_method("spawn", force=True)
|
||||
|
||||
process = multiprocessing.Process(
|
||||
target=self._run_training_sync,
|
||||
kwargs={
|
||||
"model": model,
|
||||
"provider_config": provider_config.model_dump(),
|
||||
"dpo_config": dpo_config.model_dump(),
|
||||
"config": config.model_dump(),
|
||||
"output_dir_path": output_dir_path,
|
||||
},
|
||||
)
|
||||
process.start()
|
||||
|
||||
# Monitor the process
|
||||
while process.is_alive():
|
||||
process.join(timeout=1) # Check every second
|
||||
if not process.is_alive():
|
||||
break
|
||||
|
||||
# Get the return code
|
||||
if process.exitcode != 0:
|
||||
raise RuntimeError(f"DPO training failed with exit code {process.exitcode}")
|
||||
|
||||
memory_stats["after_training"] = get_memory_stats(device)
|
||||
|
||||
checkpoints = []
|
||||
if output_dir_path:
|
||||
checkpoints = create_checkpoints(output_dir_path, job_uuid, model, config, "dpo_model")
|
||||
|
||||
return memory_stats, checkpoints if checkpoints else None
|
||||
finally:
|
||||
memory_stats["final"] = get_memory_stats(device)
|
||||
gc.collect()
|
||||
|
|
@ -0,0 +1,269 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.post_training import Checkpoint, TrainingConfig
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .config import HuggingFacePostTrainingConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="post_training")
|
||||
|
||||
|
||||
def setup_environment():
|
||||
"""Setup common environment variables for training."""
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
os.environ["MKL_THREADING_LAYER"] = "GNU"
|
||||
os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
|
||||
os.environ["MKL_NUM_THREADS"] = "1"
|
||||
|
||||
|
||||
def bytes_to_gb(to_convert: int) -> str:
|
||||
"""Converts memory stats to GB and formats to 2 decimal places.
|
||||
Args:
|
||||
to_convert: Memory value in bytes
|
||||
Returns:
|
||||
str: Memory value in GB formatted to 2 decimal places
|
||||
"""
|
||||
return f"{(to_convert / (1024**3)):.2f}"
|
||||
|
||||
|
||||
def get_memory_stats(device: torch.device) -> dict[str, Any]:
|
||||
"""Get memory statistics for the given device."""
|
||||
stats = {
|
||||
"system_memory": {
|
||||
"total": bytes_to_gb(psutil.virtual_memory().total),
|
||||
"available": bytes_to_gb(psutil.virtual_memory().available),
|
||||
"used": bytes_to_gb(psutil.virtual_memory().used),
|
||||
"percent": psutil.virtual_memory().percent,
|
||||
}
|
||||
}
|
||||
|
||||
if device.type == "cuda":
|
||||
stats["device_memory"] = {
|
||||
"allocated": bytes_to_gb(torch.cuda.memory_allocated(device)),
|
||||
"reserved": bytes_to_gb(torch.cuda.memory_reserved(device)),
|
||||
"max_allocated": bytes_to_gb(torch.cuda.max_memory_allocated(device)),
|
||||
}
|
||||
elif device.type == "mps":
|
||||
# MPS doesn't provide direct memory stats, but we can track system memory
|
||||
stats["device_memory"] = {
|
||||
"note": "MPS memory stats not directly available",
|
||||
"system_memory_used": bytes_to_gb(psutil.virtual_memory().used),
|
||||
}
|
||||
elif device.type == "cpu":
|
||||
# For CPU, we track process memory usage
|
||||
process = psutil.Process()
|
||||
stats["device_memory"] = {
|
||||
"process_rss": bytes_to_gb(process.memory_info().rss),
|
||||
"process_vms": bytes_to_gb(process.memory_info().vms),
|
||||
"process_percent": process.memory_percent(),
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def setup_torch_device(device_str: str) -> torch.device:
|
||||
"""Initialize and validate a PyTorch device.
|
||||
This function handles device initialization and validation for different device types:
|
||||
- CUDA: Validates CUDA availability and handles device selection
|
||||
- MPS: Validates MPS availability for Apple Silicon
|
||||
- CPU: Basic validation
|
||||
- HPU: Raises error as it's not supported
|
||||
Args:
|
||||
device_str: String specifying the device ('cuda', 'cpu', 'mps')
|
||||
Returns:
|
||||
torch.device: The initialized and validated device
|
||||
Raises:
|
||||
RuntimeError: If device initialization fails or device is not supported
|
||||
"""
|
||||
try:
|
||||
device = torch.device(device_str)
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError(f"Error getting Torch Device {str(e)}") from e
|
||||
|
||||
# Validate device capabilities
|
||||
if device.type == "cuda":
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError(
|
||||
f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device."
|
||||
)
|
||||
if device.index is None:
|
||||
device = torch.device(device.type, torch.cuda.current_device())
|
||||
elif device.type == "mps":
|
||||
if not torch.backends.mps.is_available():
|
||||
raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.")
|
||||
elif device.type == "hpu":
|
||||
raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.")
|
||||
|
||||
return device
|
||||
|
||||
|
||||
async def load_rows_from_dataset(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
|
||||
"""Load dataset from llama stack dataset provider"""
|
||||
try:
|
||||
all_rows = await datasetio_api.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
limit=-1,
|
||||
)
|
||||
if not isinstance(all_rows.data, list):
|
||||
raise RuntimeError("Expected dataset data to be a list")
|
||||
return all_rows.data
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
|
||||
|
||||
|
||||
def load_model(
|
||||
model: str,
|
||||
device: torch.device,
|
||||
provider_config: HuggingFacePostTrainingConfig,
|
||||
) -> AutoModelForCausalLM:
|
||||
"""Load and initialize the model for training.
|
||||
Args:
|
||||
model: The model identifier to load
|
||||
device: The device to load the model onto
|
||||
provider_config: Provider-specific configuration
|
||||
Returns:
|
||||
The loaded and initialized model
|
||||
Raises:
|
||||
RuntimeError: If model loading fails
|
||||
"""
|
||||
logger.info("Loading the base model")
|
||||
try:
|
||||
model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
|
||||
model_obj = AutoModelForCausalLM.from_pretrained(
|
||||
model,
|
||||
torch_dtype="auto" if device.type != "cpu" else "float32",
|
||||
quantization_config=None,
|
||||
config=model_config,
|
||||
**provider_config.model_specific_config,
|
||||
)
|
||||
# Always move model to specified device
|
||||
model_obj = model_obj.to(device)
|
||||
logger.info(f"Model loaded and moved to device: {model_obj.device}")
|
||||
return model_obj
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load model: {str(e)}") from e
|
||||
|
||||
|
||||
def split_dataset(ds: Dataset) -> tuple[Dataset, Dataset]:
|
||||
"""Split dataset into train and validation sets.
|
||||
Args:
|
||||
ds: Dataset to split
|
||||
Returns:
|
||||
tuple: (train_dataset, eval_dataset)
|
||||
"""
|
||||
logger.info("Splitting dataset into train and validation sets")
|
||||
train_val_split = ds.train_test_split(test_size=0.1, seed=42)
|
||||
train_dataset = train_val_split["train"]
|
||||
eval_dataset = train_val_split["test"]
|
||||
logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
|
||||
def setup_signal_handlers():
|
||||
"""Setup signal handlers for graceful shutdown."""
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
logger.info(f"Received signal {signum}, initiating graceful shutdown")
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
|
||||
def calculate_training_steps(steps_per_epoch: int, config: TrainingConfig) -> dict[str, int]:
|
||||
"""Calculate training steps and logging configuration.
|
||||
Args:
|
||||
steps_per_epoch: Number of training steps per epoch
|
||||
config: Training configuration
|
||||
Returns:
|
||||
dict: Dictionary with calculated step values
|
||||
"""
|
||||
total_steps = steps_per_epoch * config.n_epochs
|
||||
max_steps = min(config.max_steps_per_epoch, total_steps)
|
||||
logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch
|
||||
|
||||
logger.info("Training configuration:")
|
||||
logger.info(f"- Steps per epoch: {steps_per_epoch}")
|
||||
logger.info(f"- Total steps: {total_steps}")
|
||||
logger.info(f"- Max steps: {max_steps}")
|
||||
logger.info(f"- Logging steps: {logging_steps}")
|
||||
|
||||
return {"total_steps": total_steps, "max_steps": max_steps, "logging_steps": logging_steps}
|
||||
|
||||
|
||||
def get_save_strategy(output_dir_path: Path | None) -> tuple[str, str]:
|
||||
"""Get save and evaluation strategy based on output directory.
|
||||
Args:
|
||||
output_dir_path: Optional path to save the model
|
||||
Returns:
|
||||
tuple: (save_strategy, eval_strategy)
|
||||
"""
|
||||
if output_dir_path:
|
||||
logger.info(f"Will save checkpoints to {output_dir_path}")
|
||||
return "epoch", "epoch"
|
||||
return "no", "no"
|
||||
|
||||
|
||||
def create_checkpoints(
|
||||
output_dir_path: Path, job_uuid: str, model: str, config: TrainingConfig, final_model_name: str
|
||||
) -> list[Checkpoint]:
|
||||
"""Create checkpoint objects from training output.
|
||||
Args:
|
||||
output_dir_path: Path to the training output directory
|
||||
job_uuid: Unique identifier for the training job
|
||||
model: Model identifier
|
||||
config: Training configuration
|
||||
final_model_name: Name of the final model directory ("merged_model" for SFT, "dpo_model" for DPO)
|
||||
Returns:
|
||||
List of Checkpoint objects
|
||||
"""
|
||||
checkpoints = []
|
||||
|
||||
# Add checkpoint directories
|
||||
checkpoint_dirs = sorted(
|
||||
[d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()],
|
||||
key=lambda x: int(x.name.split("-")[1]),
|
||||
)
|
||||
|
||||
for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1):
|
||||
created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC)
|
||||
checkpoint = Checkpoint(
|
||||
identifier=checkpoint_dir.name,
|
||||
created_at=created_time,
|
||||
epoch=epoch_number,
|
||||
post_training_job_id=job_uuid,
|
||||
path=str(checkpoint_dir),
|
||||
)
|
||||
checkpoints.append(checkpoint)
|
||||
|
||||
# Add final model
|
||||
final_model_path = output_dir_path / final_model_name
|
||||
if final_model_path.exists():
|
||||
training_type = "sft" if final_model_name == "merged_model" else "dpo"
|
||||
checkpoint = Checkpoint(
|
||||
identifier=f"{model}-{training_type}-{config.n_epochs}",
|
||||
created_at=datetime.now(UTC),
|
||||
epoch=config.n_epochs,
|
||||
post_training_job_id=job_uuid,
|
||||
path=str(final_model_path),
|
||||
)
|
||||
checkpoints.append(checkpoint)
|
||||
|
||||
return checkpoints
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import Api
|
||||
|
||||
from .config import TorchtunePostTrainingConfig
|
||||
|
||||
# post_training api and the torchtune provider is still experimental and under heavy development
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: TorchtunePostTrainingConfig,
|
||||
deps: dict[Api, Any],
|
||||
):
|
||||
from .post_training import TorchtunePostTrainingImpl
|
||||
|
||||
impl = TorchtunePostTrainingImpl(
|
||||
config,
|
||||
deps[Api.datasetio],
|
||||
deps[Api.datasets],
|
||||
)
|
||||
return impl
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -0,0 +1,240 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
from torchtune import training
|
||||
from torchtune.models import convert_weights
|
||||
from torchtune.training.checkpointing._utils import (
|
||||
ADAPTER_CONFIG_FNAME,
|
||||
ADAPTER_MODEL_FNAME,
|
||||
REPO_ID_FNAME,
|
||||
SUFFIXES_TO_NOT_COPY,
|
||||
ModelType,
|
||||
copy_files,
|
||||
safe_torch_load,
|
||||
)
|
||||
from torchtune.utils._logging import get_logger
|
||||
|
||||
logger = get_logger("DEBUG")
|
||||
|
||||
|
||||
class TorchtuneCheckpointer:
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
training_algorithm: str,
|
||||
checkpoint_dir: str,
|
||||
checkpoint_files: list[str],
|
||||
output_dir: str,
|
||||
model_type: str,
|
||||
):
|
||||
# Fail fast if ``checkpoint_files`` is invalid
|
||||
# TODO: support loading more than one file
|
||||
if len(checkpoint_files) != 1:
|
||||
raise ValueError(
|
||||
"Currently we only support reading from a single torchtune checkpoint file. "
|
||||
f"Got {len(checkpoint_files)} files instead."
|
||||
)
|
||||
self._checkpoint_file = checkpoint_files[0]
|
||||
self._model_id = model_id
|
||||
self._training_algorithm = training_algorithm
|
||||
self._checkpoint_dir = Path(checkpoint_dir)
|
||||
self._model_type = ModelType[model_type]
|
||||
self._output_dir = output_dir
|
||||
# get ckpt paths
|
||||
self._checkpoint_path = Path.joinpath(self._checkpoint_dir, self._checkpoint_file)
|
||||
|
||||
def load_checkpoint(self) -> dict[str, Any]:
|
||||
"""
|
||||
Load Meta checkpoint from file. Currently only loading from a single file is supported.
|
||||
"""
|
||||
state_dict: dict[str, Any] = {}
|
||||
model_state_dict = safe_torch_load(self._checkpoint_path)
|
||||
if self._model_type == ModelType.LLAMA3_VISION:
|
||||
from torchtune.models.llama3_2_vision._convert_weights import (
|
||||
llama3_vision_meta_to_tune,
|
||||
)
|
||||
|
||||
state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(model_state_dict)
|
||||
else:
|
||||
state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(model_state_dict)
|
||||
|
||||
# llama3_2 has tied weights, so we need to remove the output.weight key
|
||||
if self._model_type == ModelType.LLAMA3_2:
|
||||
logger.info(
|
||||
"Identified model_type = Llama3_2. Ignoring output.weight in"
|
||||
" checkpoint in favor of the tok_embedding.weight"
|
||||
" tied weights."
|
||||
)
|
||||
state_dict[training.MODEL_KEY].pop("output.weight")
|
||||
|
||||
return state_dict
|
||||
|
||||
def save_checkpoint(
|
||||
self,
|
||||
state_dict: dict[str, Any],
|
||||
epoch: int,
|
||||
adapter_only: bool = False,
|
||||
checkpoint_format: str | None = None,
|
||||
) -> str:
|
||||
model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"
|
||||
if checkpoint_format == "meta" or checkpoint_format is None:
|
||||
self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only)
|
||||
elif checkpoint_format == "huggingface":
|
||||
# Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now
|
||||
self._save_hf_format_checkpoint(model_file_path, state_dict)
|
||||
else:
|
||||
raise ValueError(f"Unsupported checkpoint format: {format}")
|
||||
return str(model_file_path)
|
||||
|
||||
def _save_meta_format_checkpoint(
|
||||
self,
|
||||
model_file_path: Path,
|
||||
state_dict: dict[str, Any],
|
||||
adapter_only: bool = False,
|
||||
) -> None:
|
||||
model_file_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# copy the related files for inference
|
||||
source_path = Path.joinpath(self._checkpoint_dir, "params.json")
|
||||
if source_path.exists():
|
||||
shutil.copy(
|
||||
source_path,
|
||||
Path.joinpath(model_file_path, "params.json"),
|
||||
)
|
||||
source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
|
||||
if source_path.exists():
|
||||
shutil.copy(
|
||||
source_path,
|
||||
Path.joinpath(model_file_path, "tokenizer.model"),
|
||||
)
|
||||
source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
|
||||
if source_path.exists():
|
||||
shutil.copy(
|
||||
source_path,
|
||||
Path.joinpath(model_file_path, "orig_params.json"),
|
||||
)
|
||||
|
||||
if not adapter_only:
|
||||
model_state_dict = state_dict[training.MODEL_KEY]
|
||||
if self._model_type == ModelType.LLAMA3_VISION:
|
||||
from torchtune.models.llama3_2_vision._convert_weights import (
|
||||
llama3_vision_tune_to_meta,
|
||||
)
|
||||
|
||||
state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(model_state_dict)
|
||||
else:
|
||||
# llama3_2 has tied weights, so we need to add the output.weight key
|
||||
if self._model_type == ModelType.LLAMA3_2 and "output.weight" not in model_state_dict:
|
||||
model_state_dict["output.weight"] = model_state_dict["tok_embeddings.weight"]
|
||||
|
||||
state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(model_state_dict)
|
||||
|
||||
model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")
|
||||
|
||||
torch.save(state_dict[training.MODEL_KEY], model_file_name)
|
||||
logger.info(
|
||||
"Model checkpoint of size "
|
||||
f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
|
||||
f"saved to {model_file_name}"
|
||||
)
|
||||
|
||||
if training.ADAPTER_KEY in state_dict:
|
||||
adapter_file_path = model_file_path / "adapter"
|
||||
adapter_file_path.mkdir(parents=True, exist_ok=True)
|
||||
adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
|
||||
torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
|
||||
logger.info(
|
||||
"Adapter checkpoint of size "
|
||||
f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
|
||||
f"saved to {adapter_file_name}"
|
||||
)
|
||||
|
||||
elif adapter_only:
|
||||
raise ValueError(
|
||||
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
|
||||
)
|
||||
|
||||
def _save_hf_format_checkpoint(
|
||||
self,
|
||||
model_file_path: Path,
|
||||
state_dict: dict[str, Any],
|
||||
) -> None:
|
||||
# the config.json file contains model params needed for state dict conversion
|
||||
config = json.loads(Path.joinpath(self._checkpoint_dir.parent, "config.json").read_text())
|
||||
|
||||
# repo_id is necessary for when saving an adapter config, so its compatible with HF.
|
||||
# This json file is produced and saved in the download step.
|
||||
# contents are {"repo_id": "some_model/some_model_version"}
|
||||
repo_id_path = Path.joinpath(self._checkpoint_dir.parent, REPO_ID_FNAME).with_suffix(".json")
|
||||
self.repo_id = None
|
||||
if repo_id_path.exists():
|
||||
with open(repo_id_path) as json_file:
|
||||
data = json.load(json_file)
|
||||
self.repo_id = data.get("repo_id")
|
||||
|
||||
if training.ADAPTER_KEY in state_dict:
|
||||
# TODO: saving it "as is" is a requirement because, if we only save with
|
||||
# convert_weights.tune_to_peft_adapter_weights, we do NOT have a fn
|
||||
# convert_weights.peft_to_tune. The .pt format is not needed, but
|
||||
# it is an easy way to distinguish the adapters. Ideally we should save only one.
|
||||
output_path = Path.joinpath(model_file_path, ADAPTER_MODEL_FNAME).with_suffix(".pt")
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(state_dict[training.ADAPTER_KEY], output_path)
|
||||
logger.info(
|
||||
f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
|
||||
)
|
||||
|
||||
state_dict[training.ADAPTER_KEY] = convert_weights.tune_to_peft_adapter_weights(
|
||||
state_dict[training.ADAPTER_KEY],
|
||||
num_heads=config["num_attention_heads"],
|
||||
num_kv_heads=config["num_key_value_heads"],
|
||||
dim=config["hidden_size"],
|
||||
head_dim=config.get("head_dim", None),
|
||||
)
|
||||
output_path = Path.joinpath(model_file_path, "adapter", ADAPTER_MODEL_FNAME)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_path = output_path.with_suffix(".safetensors")
|
||||
save_file(
|
||||
state_dict[training.ADAPTER_KEY],
|
||||
output_path,
|
||||
metadata={"format": "pt"},
|
||||
)
|
||||
logger.info(
|
||||
f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
|
||||
)
|
||||
|
||||
if training.ADAPTER_CONFIG in state_dict:
|
||||
state_dict[training.ADAPTER_CONFIG] = convert_weights.tune_to_peft_adapter_config(
|
||||
adapter_config=state_dict[training.ADAPTER_CONFIG],
|
||||
base_model_name_or_path=self.repo_id,
|
||||
)
|
||||
|
||||
output_path = Path.joinpath(model_file_path, "adapter", ADAPTER_CONFIG_FNAME).with_suffix(".json")
|
||||
with open(output_path, "w") as f:
|
||||
json.dump(state_dict[training.ADAPTER_CONFIG], f)
|
||||
logger.info(
|
||||
f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
|
||||
)
|
||||
|
||||
# Save all files in ckpt_dir, except model weights and mapping, to output_dir/epoch_{epoch}
|
||||
# So its easy to run inference with the model using this epoch's checkpoint
|
||||
copy_files(
|
||||
self._checkpoint_dir.parent,
|
||||
model_file_path,
|
||||
ignore_suffixes=SUFFIXES_TO_NOT_COPY,
|
||||
)
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel
|
||||
from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
|
||||
from torchtune.models.llama3 import llama3_tokenizer
|
||||
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||
from torchtune.models.llama3_1 import lora_llama3_1_8b
|
||||
from torchtune.models.llama3_2 import lora_llama3_2_3b
|
||||
from torchtune.modules.transforms import Transform
|
||||
|
||||
from llama_stack.apis.post_training import DatasetFormat
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.models.llama.sku_types import Model
|
||||
|
||||
BuildLoraModelCallable = Callable[..., torch.nn.Module]
|
||||
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
model_definition: BuildLoraModelCallable
|
||||
tokenizer_type: BuildTokenizerCallable
|
||||
checkpoint_type: str
|
||||
|
||||
|
||||
MODEL_CONFIGS: dict[str, ModelConfig] = {
|
||||
"Llama3.2-3B-Instruct": ModelConfig(
|
||||
model_definition=lora_llama3_2_3b,
|
||||
tokenizer_type=llama3_tokenizer,
|
||||
checkpoint_type="LLAMA3_2",
|
||||
),
|
||||
"Llama3.1-8B-Instruct": ModelConfig(
|
||||
model_definition=lora_llama3_1_8b,
|
||||
tokenizer_type=llama3_tokenizer,
|
||||
checkpoint_type="LLAMA3",
|
||||
),
|
||||
}
|
||||
|
||||
DATA_FORMATS: dict[str, Transform] = {
|
||||
"instruct": InputOutputToMessages,
|
||||
"dialog": ShareGPTToMessages,
|
||||
}
|
||||
|
||||
|
||||
def _validate_model_id(model_id: str) -> Model:
|
||||
model = resolve_model(model_id)
|
||||
if model is None or model.core_model_id.value not in MODEL_CONFIGS:
|
||||
raise ValueError(f"Model {model_id} is not supported.")
|
||||
return model
|
||||
|
||||
|
||||
async def get_model_definition(
|
||||
model_id: str,
|
||||
) -> BuildLoraModelCallable:
|
||||
model = _validate_model_id(model_id)
|
||||
model_config = MODEL_CONFIGS[model.core_model_id.value]
|
||||
if not hasattr(model_config, "model_definition"):
|
||||
raise ValueError(f"Model {model_id} does not have model definition.")
|
||||
return model_config.model_definition
|
||||
|
||||
|
||||
async def get_tokenizer_type(
|
||||
model_id: str,
|
||||
) -> BuildTokenizerCallable:
|
||||
model = _validate_model_id(model_id)
|
||||
model_config = MODEL_CONFIGS[model.core_model_id.value]
|
||||
if not hasattr(model_config, "tokenizer_type"):
|
||||
raise ValueError(f"Model {model_id} does not have tokenizer_type.")
|
||||
return model_config.tokenizer_type
|
||||
|
||||
|
||||
async def get_checkpointer_model_type(
|
||||
model_id: str,
|
||||
) -> str:
|
||||
"""
|
||||
checkpointer model type is used in checkpointer for some special treatment on some specific model types
|
||||
For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
|
||||
"""
|
||||
model = _validate_model_id(model_id)
|
||||
model_config = MODEL_CONFIGS[model.core_model_id.value]
|
||||
if not hasattr(model_config, "checkpoint_type"):
|
||||
raise ValueError(f"Model {model_id} does not have checkpoint_type.")
|
||||
return model_config.checkpoint_type
|
||||
|
||||
|
||||
async def get_data_transform(data_format: DatasetFormat) -> Transform:
|
||||
return DATA_FORMATS[data_format.value]
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TorchtunePostTrainingConfig(BaseModel):
|
||||
torch_seed: int | None = None
|
||||
checkpoint_format: Literal["meta", "huggingface"] | None = "meta"
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"checkpoint_format": "meta",
|
||||
}
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
|
||||
|
||||
|
||||
def llama_stack_instruct_to_torchtune_instruct(
|
||||
sample: Mapping[str, Any],
|
||||
) -> Mapping[str, Any]:
|
||||
assert ColumnName.chat_completion_input.value in sample and ColumnName.expected_answer.value in sample, (
|
||||
"Invalid input row"
|
||||
)
|
||||
input_messages = json.loads(sample[ColumnName.chat_completion_input.value])
|
||||
|
||||
assert len(input_messages) == 1, "llama stack intruct dataset format only supports 1 user message"
|
||||
input_message = input_messages[0]
|
||||
|
||||
assert "content" in input_message, "content not found in input message"
|
||||
input = input_message["content"]
|
||||
output = sample[ColumnName.expected_answer.value]
|
||||
|
||||
return {
|
||||
"input": input,
|
||||
"output": output,
|
||||
}
|
||||
|
||||
|
||||
def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
assert ColumnName.dialog.value in sample, "Invalid input row"
|
||||
role_map = {"user": "human", "assistant": "gpt"}
|
||||
dialog = json.loads(sample[ColumnName.dialog.value])
|
||||
|
||||
assert len(dialog) > 1, "dialog must have at least 2 messagse"
|
||||
roles = []
|
||||
conversations = []
|
||||
for message in dialog:
|
||||
assert "role" in message and "content" in message, "role and content must in message"
|
||||
roles.append(message["role"])
|
||||
conversations.append({"from": role_map[message["role"]], "value": message["content"]})
|
||||
|
||||
assert roles[0] == "user", "first message must be from user"
|
||||
assert "assistant" in roles, "at least 1 message should be from assistant"
|
||||
|
||||
return {"conversations": conversations}
|
||||
|
|
@ -0,0 +1,78 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
|
||||
from torchtune.data._messages import validate_messages
|
||||
from torchtune.modules.transforms import Transform
|
||||
|
||||
from llama_stack.providers.inline.post_training.torchtune.datasets.format_adapter import (
|
||||
llama_stack_chat_to_torchtune_chat,
|
||||
llama_stack_instruct_to_torchtune_instruct,
|
||||
)
|
||||
|
||||
|
||||
class SFTDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
rows: list[dict[str, Any]],
|
||||
message_transform: Transform,
|
||||
model_transform: Transform,
|
||||
dataset_type: str,
|
||||
) -> None:
|
||||
self._rows = rows
|
||||
self._message_transform = message_transform
|
||||
self._model_transform = model_transform
|
||||
self._dataset_type = dataset_type
|
||||
|
||||
def __len__(self):
|
||||
return len(self._rows)
|
||||
|
||||
def __getitem__(self, index: int) -> dict[str, Any]:
|
||||
sample = self._rows[index]
|
||||
return self._prepare_sample(sample)
|
||||
|
||||
def _prepare_sample(self, sample: Mapping[str, Any]) -> dict[str, Any]:
|
||||
if self._dataset_type == "instruct":
|
||||
sample = llama_stack_instruct_to_torchtune_instruct(sample)
|
||||
elif self._dataset_type == "dialog":
|
||||
sample = llama_stack_chat_to_torchtune_chat(sample)
|
||||
else:
|
||||
raise ValueError(f"Invalid dataset type: {self._dataset_type}")
|
||||
transformed_sample = self._message_transform(sample)
|
||||
if "messages" in transformed_sample:
|
||||
validate_messages(transformed_sample["messages"])
|
||||
|
||||
tokenized_dict: dict[str, Any] = self._model_transform(transformed_sample)
|
||||
|
||||
if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
|
||||
keys_str = ", ".join(tokenized_dict.keys())
|
||||
error_message = (
|
||||
f"model_transform returned the following keys: {keys_str}. Must return 'tokens' and 'mask' as keys."
|
||||
)
|
||||
raise ValueError(error_message)
|
||||
|
||||
# Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens
|
||||
tokenized_dict["labels"] = list(
|
||||
np.where(
|
||||
tokenized_dict["mask"],
|
||||
CROSS_ENTROPY_IGNORE_IDX,
|
||||
tokenized_dict["tokens"],
|
||||
)
|
||||
)
|
||||
assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"])
|
||||
|
||||
return tokenized_dict
|
||||
|
|
@ -0,0 +1,178 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
AlgorithmConfig,
|
||||
Checkpoint,
|
||||
DPOAlignmentConfig,
|
||||
JobStatus,
|
||||
ListPostTrainingJobsResponse,
|
||||
LoraFinetuningConfig,
|
||||
PostTrainingJob,
|
||||
PostTrainingJobArtifactsResponse,
|
||||
PostTrainingJobStatusResponse,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.torchtune.config import (
|
||||
TorchtunePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
|
||||
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
|
||||
|
||||
|
||||
class TrainingArtifactType(Enum):
|
||||
CHECKPOINT = "checkpoint"
|
||||
RESOURCES_STATS = "resources_stats"
|
||||
|
||||
|
||||
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
|
||||
|
||||
|
||||
class TorchtunePostTrainingImpl:
|
||||
def __init__(
|
||||
self,
|
||||
config: TorchtunePostTrainingConfig,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets: Datasets,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets
|
||||
self._scheduler = Scheduler()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await self._scheduler.shutdown()
|
||||
|
||||
@staticmethod
|
||||
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
|
||||
return JobArtifact(
|
||||
type=TrainingArtifactType.CHECKPOINT.value,
|
||||
name=checkpoint.identifier,
|
||||
uri=checkpoint.path,
|
||||
metadata=dict(checkpoint),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resources_stats_to_artifact(resources_stats: dict[str, Any]) -> JobArtifact:
|
||||
return JobArtifact(
|
||||
type=TrainingArtifactType.RESOURCES_STATS.value,
|
||||
name=TrainingArtifactType.RESOURCES_STATS.value,
|
||||
metadata=resources_stats,
|
||||
)
|
||||
|
||||
async def supervised_fine_tune(
|
||||
self,
|
||||
job_uuid: str,
|
||||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: dict[str, Any],
|
||||
logger_config: dict[str, Any],
|
||||
model: str,
|
||||
checkpoint_dir: str | None,
|
||||
algorithm_config: AlgorithmConfig | None,
|
||||
) -> PostTrainingJob:
|
||||
if isinstance(algorithm_config, LoraFinetuningConfig):
|
||||
|
||||
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
|
||||
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
|
||||
LoraFinetuningSingleDevice,
|
||||
)
|
||||
|
||||
on_log_message_cb("Starting Lora finetuning")
|
||||
|
||||
recipe = LoraFinetuningSingleDevice(
|
||||
self.config,
|
||||
job_uuid,
|
||||
training_config,
|
||||
hyperparam_search_config,
|
||||
logger_config,
|
||||
model,
|
||||
checkpoint_dir,
|
||||
algorithm_config,
|
||||
self.datasetio_api,
|
||||
self.datasets_api,
|
||||
)
|
||||
await recipe.setup()
|
||||
|
||||
resources_allocated, checkpoints = await recipe.train()
|
||||
|
||||
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
|
||||
for checkpoint in checkpoints:
|
||||
artifact = self._checkpoint_to_artifact(checkpoint)
|
||||
on_artifact_collected_cb(artifact)
|
||||
|
||||
on_status_change_cb(SchedulerJobStatus.completed)
|
||||
on_log_message_cb("Lora finetuning completed")
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
|
||||
return PostTrainingJob(job_uuid=job_uuid)
|
||||
|
||||
async def preference_optimize(
|
||||
self,
|
||||
job_uuid: str,
|
||||
finetuned_model: str,
|
||||
algorithm_config: DPOAlignmentConfig,
|
||||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: dict[str, Any],
|
||||
logger_config: dict[str, Any],
|
||||
) -> PostTrainingJob:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
|
||||
return ListPostTrainingJobsResponse(
|
||||
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_artifacts_metadata_by_type(job, artifact_type):
|
||||
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
|
||||
|
||||
@classmethod
|
||||
def _get_checkpoints(cls, job):
|
||||
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
|
||||
|
||||
@classmethod
|
||||
def _get_resources_allocated(cls, job):
|
||||
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
|
||||
return data[0] if data else None
|
||||
|
||||
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
|
||||
match job.status:
|
||||
# TODO: Add support for other statuses to API
|
||||
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
|
||||
status = JobStatus.scheduled
|
||||
case SchedulerJobStatus.running:
|
||||
status = JobStatus.in_progress
|
||||
case SchedulerJobStatus.completed:
|
||||
status = JobStatus.completed
|
||||
case SchedulerJobStatus.failed:
|
||||
status = JobStatus.failed
|
||||
case _:
|
||||
raise NotImplementedError()
|
||||
|
||||
return PostTrainingJobStatusResponse(
|
||||
job_uuid=job_uuid,
|
||||
status=status,
|
||||
scheduled_at=job.scheduled_at,
|
||||
started_at=job.started_at,
|
||||
completed_at=job.completed_at,
|
||||
checkpoints=self._get_checkpoints(job),
|
||||
resources_allocated=self._get_resources_allocated(job),
|
||||
)
|
||||
|
||||
async def cancel_training_job(self, job_uuid: str) -> None:
|
||||
self._scheduler.cancel(job_uuid)
|
||||
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -0,0 +1,588 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from torchtune import modules, training
|
||||
from torchtune import utils as torchtune_utils
|
||||
from torchtune.data import padded_collate_sft
|
||||
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||
from torchtune.modules.loss import CEWithChunkedOutputLoss
|
||||
from torchtune.modules.peft import (
|
||||
get_adapter_params,
|
||||
get_adapter_state_dict,
|
||||
get_lora_module_names,
|
||||
get_merged_lora_ckpt,
|
||||
set_trainable_params,
|
||||
validate_missing_and_unexpected_for_lora,
|
||||
)
|
||||
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
|
||||
from torchtune.training.metric_logging import DiskLogger
|
||||
from tqdm import tqdm
|
||||
|
||||
from llama_stack.apis.common.training_types import PostTrainingMetric
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
Checkpoint,
|
||||
DataConfig,
|
||||
LoraFinetuningConfig,
|
||||
OptimizerConfig,
|
||||
QATFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
from llama_stack.providers.inline.post_training.torchtune.common import utils
|
||||
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
|
||||
TorchtuneCheckpointer,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.torchtune.config import (
|
||||
TorchtunePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
|
||||
|
||||
log = get_logger(name=__name__, category="post_training")
|
||||
|
||||
|
||||
class LoraFinetuningSingleDevice:
|
||||
# This recipe doesn't include several training efficiency setting within origin torchtune repo, including
|
||||
# - compile
|
||||
# - activation offloading
|
||||
|
||||
# Resume from checkpoint hasn't been supported yet
|
||||
# Validation hasn't been supported yet
|
||||
|
||||
# Currently logging only logs limited training metrics to local disk
|
||||
# will figure out more loggings and how it works with telemetry in future PRs
|
||||
|
||||
_checkpointer: TorchtuneCheckpointer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: TorchtunePostTrainingConfig,
|
||||
job_uuid: str,
|
||||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: dict[str, Any],
|
||||
logger_config: dict[str, Any],
|
||||
model: str,
|
||||
checkpoint_dir: str | None,
|
||||
algorithm_config: LoraFinetuningConfig | QATFinetuningConfig | None,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets_api: Datasets,
|
||||
) -> None:
|
||||
assert isinstance(training_config.data_config, DataConfig), "DataConfig must be initialized"
|
||||
|
||||
self.job_uuid = job_uuid
|
||||
self.training_config = training_config
|
||||
if not isinstance(algorithm_config, LoraFinetuningConfig):
|
||||
raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning")
|
||||
self.algorithm_config = algorithm_config
|
||||
self._device = torchtune_utils.get_device()
|
||||
self._dtype = training.get_dtype(training_config.dtype, device=self._device)
|
||||
self.model_id = model
|
||||
|
||||
def model_checkpoint_dir(model) -> str:
|
||||
checkpoint_dir = Path(model_local_dir(model.descriptor()))
|
||||
|
||||
paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
|
||||
if not any(p.exists() for p in paths):
|
||||
checkpoint_dir = checkpoint_dir / "original"
|
||||
|
||||
hf_repo = model.huggingface_repo or f"meta-llama/{model.descriptor()}"
|
||||
assert checkpoint_dir.exists(), (
|
||||
f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
|
||||
f"Please download the model using `huggingface-cli download {hf_repo} --local-dir ~/.llama/{model.descriptor()}`"
|
||||
)
|
||||
return str(checkpoint_dir)
|
||||
|
||||
if checkpoint_dir and checkpoint_dir != "null":
|
||||
self.checkpoint_dir = checkpoint_dir
|
||||
else:
|
||||
model_obj = resolve_model(self.model_id)
|
||||
if model_obj is None:
|
||||
raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
|
||||
self.checkpoint_dir = model_checkpoint_dir(model_obj)
|
||||
|
||||
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
|
||||
self._checkpoint_format = config.checkpoint_format
|
||||
|
||||
self.seed = training.set_seed(seed=config.torch_seed)
|
||||
self.epochs_run = 0
|
||||
self.total_epochs = training_config.n_epochs
|
||||
self._data_format = training_config.data_config.data_format
|
||||
self._shuffle = training_config.data_config.shuffle
|
||||
self._batch_size = training_config.data_config.batch_size
|
||||
self._train_on_input = training_config.data_config.train_on_input
|
||||
|
||||
# this is important for debugging purpose
|
||||
self.max_steps_per_epoch = training_config.max_steps_per_epoch
|
||||
self.global_step = 0
|
||||
|
||||
self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
|
||||
self.max_validation_steps = training_config.max_validation_steps
|
||||
|
||||
self._clip_grad_norm = 1.0
|
||||
|
||||
self._enable_activation_checkpointing = False
|
||||
self._enable_activation_offloading = False
|
||||
if training_config.efficiency_config:
|
||||
if training_config.efficiency_config.enable_activation_checkpointing:
|
||||
self._enable_activation_checkpointing = (
|
||||
training_config.efficiency_config.enable_activation_checkpointing
|
||||
)
|
||||
if training_config.efficiency_config.enable_activation_offloading:
|
||||
self._enable_activation_offloading = training_config.efficiency_config.enable_activation_offloading
|
||||
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets_api
|
||||
|
||||
async def load_checkpoint(self):
|
||||
def get_checkpoint_files(checkpoint_dir: str) -> list[str]:
|
||||
try:
|
||||
# List all files in the given directory
|
||||
files = os.listdir(checkpoint_dir)
|
||||
# Filter files that end with .pth
|
||||
pth_files = [file for file in files if file.endswith(".pth")]
|
||||
return pth_files
|
||||
except FileNotFoundError:
|
||||
return [f"Error: The directory '{checkpoint_dir}' does not exist."]
|
||||
|
||||
self._checkpointer = TorchtuneCheckpointer(
|
||||
model_id=self.model_id,
|
||||
training_algorithm="sft",
|
||||
checkpoint_dir=self.checkpoint_dir,
|
||||
checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
|
||||
output_dir=self._output_dir,
|
||||
model_type=await utils.get_checkpointer_model_type(self.model_id),
|
||||
)
|
||||
checkpoint_dict = self._checkpointer.load_checkpoint()
|
||||
return checkpoint_dict
|
||||
|
||||
async def setup(self) -> None:
|
||||
checkpoint_dict = await self.load_checkpoint()
|
||||
|
||||
self._model = await self._setup_model(
|
||||
enable_activation_checkpointing=self._enable_activation_checkpointing,
|
||||
enable_activation_offloading=self._enable_activation_offloading,
|
||||
base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
|
||||
lora_weights_state_dict=None,
|
||||
)
|
||||
log.info(f"Model is initialized with precision {self._dtype}.")
|
||||
|
||||
self._tokenizer = await self._setup_tokenizer()
|
||||
log.info("Tokenizer is initialized.")
|
||||
|
||||
assert isinstance(self.training_config.optimizer_config, OptimizerConfig), "OptimizerConfig must be initialized"
|
||||
self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config)
|
||||
log.info("Optimizer is initialized.")
|
||||
|
||||
self._loss_fn = CEWithChunkedOutputLoss()
|
||||
self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
|
||||
log.info("Loss is initialized.")
|
||||
|
||||
assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
|
||||
|
||||
self._training_sampler, self._training_dataloader = await self._setup_data(
|
||||
dataset_id=self.training_config.data_config.dataset_id,
|
||||
tokenizer=self._tokenizer,
|
||||
shuffle=self._shuffle,
|
||||
batch_size=self._batch_size,
|
||||
)
|
||||
|
||||
if self.training_config.data_config.validation_dataset_id:
|
||||
_, self._validation_dataloader = await self._setup_data(
|
||||
dataset_id=self.training_config.data_config.validation_dataset_id,
|
||||
tokenizer=self._tokenizer,
|
||||
shuffle=False,
|
||||
batch_size=self._batch_size,
|
||||
)
|
||||
|
||||
log.info("Dataset and Sampler are initialized.")
|
||||
|
||||
# Number of training steps in each epoch depends on the number of batches produced
|
||||
# by the dataloader and the max_steps_per_epoch param set by the user and is used
|
||||
# for logging and tracking training state. This should be computed after the dataloader
|
||||
# has been setup
|
||||
self._steps_per_epoch = len(self._training_dataloader) // self._gradient_accumulation_steps
|
||||
if self.max_steps_per_epoch is not None and self.max_steps_per_epoch < self._steps_per_epoch:
|
||||
self._steps_per_epoch = self.max_steps_per_epoch
|
||||
self.global_step = self.epochs_run * self._steps_per_epoch
|
||||
|
||||
# Learning rate scheduler can only be set up after number of steps
|
||||
# has been computed
|
||||
self._lr_scheduler = await self._setup_lr_scheduler(
|
||||
num_warmup_steps=self.training_config.optimizer_config.num_warmup_steps,
|
||||
num_training_steps=self.total_epochs * self._steps_per_epoch,
|
||||
last_epoch=self.global_step - 1,
|
||||
)
|
||||
log.info("Learning rate scheduler is initialized.")
|
||||
|
||||
# Used to ignore labels for loss computation
|
||||
self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device)
|
||||
|
||||
def _log_memory_stats(self):
|
||||
# torchtune raises: "Logging memory stats is not supported on CPU devices"; do nothing
|
||||
if self._device.type == "cpu":
|
||||
return
|
||||
memory_stats = training.get_memory_stats(device=self._device)
|
||||
training.log_memory_stats(memory_stats)
|
||||
|
||||
async def _setup_model(
|
||||
self,
|
||||
enable_activation_checkpointing: bool,
|
||||
enable_activation_offloading: bool,
|
||||
base_model_state_dict: dict[str, Any],
|
||||
lora_weights_state_dict: dict[str, Any] | None = None,
|
||||
) -> nn.Module:
|
||||
self._lora_rank = self.algorithm_config.rank
|
||||
self._lora_alpha = self.algorithm_config.alpha
|
||||
self._lora_attn_modules = list(self.algorithm_config.lora_attn_modules)
|
||||
self._apply_lora_to_mlp = self.algorithm_config.apply_lora_to_mlp
|
||||
self._apply_lora_to_output = self.algorithm_config.apply_lora_to_output
|
||||
self._use_dora = self.algorithm_config.use_dora or False
|
||||
|
||||
with training.set_default_dtype(self._dtype), self._device:
|
||||
model_type = await utils.get_model_definition(self.model_id)
|
||||
model = model_type(
|
||||
lora_attn_modules=self._lora_attn_modules,
|
||||
apply_lora_to_mlp=self._apply_lora_to_mlp,
|
||||
apply_lora_to_output=self._apply_lora_to_output,
|
||||
lora_rank=self._lora_rank,
|
||||
lora_alpha=self._lora_alpha,
|
||||
quantize_base=False,
|
||||
use_dora=self._use_dora,
|
||||
)
|
||||
|
||||
self.adapter_params = get_adapter_params(model)
|
||||
self._is_dora = any("magnitude" in k for k in self.adapter_params.keys())
|
||||
|
||||
set_trainable_params(model, self.adapter_params)
|
||||
|
||||
if enable_activation_checkpointing:
|
||||
training.set_activation_checkpointing(model, auto_wrap_policy={modules.TransformerSelfAttentionLayer})
|
||||
|
||||
base_missing, base_unexpected = model.load_state_dict(base_model_state_dict, strict=False)
|
||||
|
||||
# This is for any adapters that need to be initialized after base weights
|
||||
# have been loaded (e.g. DoRA).
|
||||
if self._is_dora:
|
||||
for m in model.modules():
|
||||
if hasattr(m, "initialize_dora_magnitude"):
|
||||
m.initialize_dora_magnitude()
|
||||
if lora_weights_state_dict:
|
||||
lora_missing, lora_unexpected = model.load_state_dict(lora_weights_state_dict, strict=False)
|
||||
else:
|
||||
lora_missing, lora_unexpected = None, None
|
||||
validate_missing_and_unexpected_for_lora(
|
||||
lora_attn_modules=self._lora_attn_modules,
|
||||
apply_lora_to_mlp=self._apply_lora_to_mlp,
|
||||
apply_lora_to_output=self._apply_lora_to_output,
|
||||
base_missing=base_missing,
|
||||
base_unexpected=base_unexpected,
|
||||
lora_missing=lora_missing,
|
||||
lora_unexpected=lora_unexpected,
|
||||
)
|
||||
|
||||
# Validate model adapter params were loaded in with the expected dtype
|
||||
training.validate_expected_param_dtype(self.adapter_params.items(), dtype=self._dtype)
|
||||
|
||||
# activation offloading
|
||||
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading)
|
||||
|
||||
self._log_memory_stats()
|
||||
|
||||
return model
|
||||
|
||||
async def _setup_tokenizer(
|
||||
self,
|
||||
) -> Llama3Tokenizer:
|
||||
tokenizer_path = self.checkpoint_dir + "/tokenizer.model"
|
||||
tokenizer_type = await utils.get_tokenizer_type(self.model_id)
|
||||
return tokenizer_type(path=tokenizer_path)
|
||||
|
||||
async def _setup_optimizer(self, optimizer_config: OptimizerConfig) -> Optimizer:
|
||||
optimizer = torch.optim.AdamW(
|
||||
params=self._model.parameters(),
|
||||
lr=optimizer_config.lr,
|
||||
betas=(0.9, 0.95),
|
||||
eps=1e-8,
|
||||
weight_decay=0.1,
|
||||
)
|
||||
return optimizer
|
||||
|
||||
async def _setup_data(
|
||||
self,
|
||||
dataset_id: str,
|
||||
tokenizer: Llama3Tokenizer,
|
||||
shuffle: bool,
|
||||
batch_size: int,
|
||||
) -> tuple[DistributedSampler, DataLoader]:
|
||||
async def fetch_rows(dataset_id: str):
|
||||
return await self.datasetio_api.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
limit=-1,
|
||||
)
|
||||
|
||||
all_rows = await fetch_rows(dataset_id)
|
||||
rows = all_rows.data
|
||||
|
||||
# TODO (xiyan): validate dataset schema
|
||||
# dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
|
||||
data_transform = await utils.get_data_transform(self._data_format)
|
||||
ds = SFTDataset(
|
||||
rows,
|
||||
message_transform=data_transform(train_on_input=self._train_on_input),
|
||||
model_transform=tokenizer,
|
||||
dataset_type=self._data_format.value,
|
||||
)
|
||||
|
||||
sampler = DistributedSampler(
|
||||
ds,
|
||||
num_replicas=1,
|
||||
rank=0,
|
||||
shuffle=shuffle,
|
||||
seed=0,
|
||||
)
|
||||
dataloader = DataLoader(
|
||||
dataset=ds,
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
# dropping last avoids shape issues with compile + flex attention
|
||||
drop_last=True,
|
||||
collate_fn=(
|
||||
partial(
|
||||
padded_collate_sft,
|
||||
padding_idx=self._tokenizer.pad_id,
|
||||
ignore_idx=self._loss_fn.ignore_index,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
return sampler, dataloader
|
||||
|
||||
async def _setup_lr_scheduler(
|
||||
self,
|
||||
num_warmup_steps: int,
|
||||
num_training_steps: int,
|
||||
last_epoch: int,
|
||||
) -> Optimizer:
|
||||
lr_scheduler = get_cosine_schedule_with_warmup(
|
||||
self._optimizer,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=num_training_steps,
|
||||
last_epoch=last_epoch,
|
||||
)
|
||||
return lr_scheduler
|
||||
|
||||
async def save_checkpoint(self, epoch: int) -> str:
|
||||
ckpt_dict = {}
|
||||
|
||||
adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
|
||||
ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
|
||||
|
||||
# Construct the full state dict with LoRA weights merged into base LLM weights
|
||||
# Move to CPU to avoid a copy on GPU
|
||||
state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}
|
||||
|
||||
merged_state_dict = get_merged_lora_ckpt(
|
||||
state_dict,
|
||||
rank=self._lora_rank,
|
||||
alpha=self._lora_alpha,
|
||||
)
|
||||
|
||||
ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
|
||||
|
||||
adapter_config = {
|
||||
"r": self._lora_rank,
|
||||
"lora_alpha": self._lora_alpha,
|
||||
"target_modules": get_lora_module_names(
|
||||
self._lora_attn_modules,
|
||||
self._apply_lora_to_mlp,
|
||||
self._apply_lora_to_output,
|
||||
),
|
||||
"peft_type": "LORA",
|
||||
}
|
||||
ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})
|
||||
|
||||
return self._checkpointer.save_checkpoint(
|
||||
ckpt_dict,
|
||||
epoch=epoch,
|
||||
checkpoint_format=self._checkpoint_format,
|
||||
)
|
||||
|
||||
async def _loss_step(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
# Shape [b, s], needed for the loss not the model
|
||||
labels = batch.pop("labels")
|
||||
# run model
|
||||
with self.activations_handling_ctx:
|
||||
logits = self._model(**batch)
|
||||
|
||||
# Shift labels to compute loss
|
||||
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
|
||||
# But this way we dont need to slice the logits. We just add an ignore index to labels.
|
||||
labels = torch.hstack((labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]))
|
||||
if not isinstance(logits, list):
|
||||
labels = labels.reshape(-1)
|
||||
logits = logits.reshape(-1, logits.size(-1))
|
||||
|
||||
loss = self._loss_fn(logits, labels)
|
||||
|
||||
# free logits otherwise it peaks backward memory
|
||||
del logits
|
||||
|
||||
return loss
|
||||
|
||||
async def train(self) -> tuple[dict[str, Any], list[Checkpoint]]:
|
||||
"""
|
||||
The core training loop.
|
||||
"""
|
||||
assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
|
||||
# Initialize tokens count and running loss (for grad accumulation)
|
||||
t0 = time.perf_counter()
|
||||
running_loss: float = 0.0
|
||||
num_tokens = 0
|
||||
|
||||
# training artifacts
|
||||
checkpoints = []
|
||||
memory_stats: dict[str, Any] = {}
|
||||
|
||||
# self.epochs_run should be non-zero when we're resuming from a checkpoint
|
||||
for curr_epoch in range(self.epochs_run, self.total_epochs):
|
||||
# Update the sampler to ensure data is correctly shuffled across epochs
|
||||
# in case shuffle is True
|
||||
metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log")
|
||||
self._training_sampler.set_epoch(curr_epoch)
|
||||
loss_to_log = 0.0
|
||||
|
||||
pbar = tqdm(total=self._steps_per_epoch)
|
||||
for idx, batch in enumerate(self._training_dataloader):
|
||||
if (
|
||||
self.max_steps_per_epoch is not None
|
||||
and (idx // self._gradient_accumulation_steps) == self.max_steps_per_epoch
|
||||
):
|
||||
break
|
||||
|
||||
torchtune_utils.batch_to_device(batch, self._device)
|
||||
|
||||
# Calculate the number of unmasked tokens in the current batch
|
||||
# and increment the total number of tokens seen in the step
|
||||
current_num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum()
|
||||
num_tokens += current_num_tokens
|
||||
|
||||
# Loss is normalized by default so we multiply by the number of tokens
|
||||
# This way we can normalize by the total number of tokens if we're accumulating gradients
|
||||
current_loss = await self._loss_step(batch) * current_num_tokens
|
||||
running_loss += current_loss.detach().item()
|
||||
current_loss.backward()
|
||||
|
||||
# Step with optimizer
|
||||
if (idx + 1) % self._gradient_accumulation_steps == 0:
|
||||
training.scale_grads(self._model, 1 / num_tokens)
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self._model.parameters(),
|
||||
max_norm=float(self._clip_grad_norm),
|
||||
)
|
||||
self._optimizer.step()
|
||||
self._optimizer.zero_grad(set_to_none=True)
|
||||
self._lr_scheduler.step()
|
||||
# Update the number of steps when the weights are updated
|
||||
self.global_step += 1
|
||||
|
||||
loss_to_log = running_loss / num_tokens
|
||||
|
||||
pbar.update(1)
|
||||
pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")
|
||||
|
||||
time_per_step = time.perf_counter() - t0
|
||||
log_dict = {
|
||||
"loss": loss_to_log,
|
||||
"lr": self._optimizer.param_groups[0]["lr"],
|
||||
"tokens_per_second_per_gpu": num_tokens / time_per_step,
|
||||
}
|
||||
|
||||
self._log_memory_stats()
|
||||
|
||||
if self._clip_grad_norm is not None:
|
||||
log_dict.update({"grad_norm": grad_norm})
|
||||
|
||||
metric_logger.log_dict(
|
||||
log_dict,
|
||||
step=self.global_step,
|
||||
)
|
||||
|
||||
# Reset running stats for the next step
|
||||
running_loss = 0.0
|
||||
num_tokens = 0
|
||||
t0 = time.perf_counter()
|
||||
|
||||
self.epochs_run += 1
|
||||
log.info("Starting checkpoint save...")
|
||||
checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
|
||||
checkpoint = Checkpoint(
|
||||
identifier=f"{self.model_id}-sft-{curr_epoch}",
|
||||
created_at=datetime.now(UTC),
|
||||
epoch=curr_epoch,
|
||||
post_training_job_id=self.job_uuid,
|
||||
path=checkpoint_path,
|
||||
)
|
||||
if self.training_config.data_config.validation_dataset_id:
|
||||
validation_loss, perplexity = await self.validation()
|
||||
training_metrics = PostTrainingMetric(
|
||||
epoch=curr_epoch,
|
||||
train_loss=loss_to_log,
|
||||
validation_loss=validation_loss,
|
||||
perplexity=perplexity,
|
||||
)
|
||||
checkpoint.training_metrics = training_metrics
|
||||
checkpoints.append(checkpoint)
|
||||
|
||||
# clean up the memory after training finishes
|
||||
evacuate_model_from_device(self._model, self._device.type)
|
||||
|
||||
return (memory_stats, checkpoints)
|
||||
|
||||
async def validation(self) -> tuple[float, float]:
|
||||
total_loss = 0.0
|
||||
total_tokens = 0
|
||||
log.info("Starting validation...")
|
||||
pbar = tqdm(total=len(self._validation_dataloader))
|
||||
for idx, batch in enumerate(self._validation_dataloader):
|
||||
if idx == self.max_validation_steps:
|
||||
break
|
||||
torchtune_utils.batch_to_device(batch, self._device)
|
||||
|
||||
# Calculate the number of unmasked tokens in the current batch
|
||||
# and increment the total number of tokens seen in the step
|
||||
num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum()
|
||||
|
||||
# Loss is normalized by default so we multiply by the number of tokens
|
||||
# This way we can normalize by the total number of tokens if we're accumulating gradients
|
||||
loss = await self._loss_step(batch) * num_tokens
|
||||
|
||||
total_loss += loss
|
||||
total_tokens += num_tokens
|
||||
|
||||
pbar.update(1)
|
||||
pbar.set_description(f"validation step: {idx}")
|
||||
|
||||
mean_loss = total_loss / total_tokens
|
||||
perplexity = torch.exp(torch.tensor(mean_loss))
|
||||
|
||||
return mean_loss, perplexity.item()
|
||||
5
src/llama_stack/providers/inline/safety/__init__.py
Normal file
5
src/llama_stack/providers/inline/safety/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import CodeScannerConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: CodeScannerConfig, deps: dict[str, Any]):
|
||||
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
|
||||
|
||||
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -0,0 +1,131 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeshield.cs import CodeShieldScanResult
|
||||
|
||||
from llama_stack.apis.inference import OpenAIMessageParam
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
from .config import CodeScannerConfig
|
||||
|
||||
log = get_logger(name=__name__, category="safety")
|
||||
|
||||
ALLOWED_CODE_SCANNER_MODEL_IDS = [
|
||||
"code-scanner",
|
||||
"code-shield",
|
||||
]
|
||||
|
||||
|
||||
class MetaReferenceCodeScannerSafetyImpl(Safety):
|
||||
def __init__(self, config: CodeScannerConfig, deps) -> None:
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
if shield.provider_resource_id not in ALLOWED_CODE_SCANNER_MODEL_IDS:
|
||||
raise ValueError(
|
||||
f"Unsupported Code Scanner ID: {shield.provider_resource_id}. Allowed IDs: {ALLOWED_CODE_SCANNER_MODEL_IDS}"
|
||||
)
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
params: dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
if not shield:
|
||||
raise ValueError(f"Shield {shield_id} not found")
|
||||
|
||||
from codeshield.cs import CodeShield
|
||||
|
||||
text = "\n".join([interleaved_content_as_str(m.content) for m in messages])
|
||||
log.info(f"Running CodeScannerShield on {text[50:]}")
|
||||
result = await CodeShield.scan_code(text)
|
||||
|
||||
violation = None
|
||||
if result.is_insecure:
|
||||
violation = SafetyViolation(
|
||||
violation_level=(ViolationLevel.ERROR),
|
||||
user_message="Sorry, I found security concerns in the code.",
|
||||
metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])},
|
||||
)
|
||||
return RunShieldResponse(violation=violation)
|
||||
|
||||
def get_moderation_object_results(self, scan_result: "CodeShieldScanResult") -> ModerationObjectResults:
|
||||
categories = {}
|
||||
category_scores = {}
|
||||
category_applied_input_types = {}
|
||||
|
||||
flagged = scan_result.is_insecure
|
||||
user_message = None
|
||||
metadata = {}
|
||||
|
||||
if scan_result.is_insecure:
|
||||
pattern_ids = [issue.pattern_id for issue in scan_result.issues_found]
|
||||
categories = dict.fromkeys(pattern_ids, True)
|
||||
category_scores = dict.fromkeys(pattern_ids, 1.0)
|
||||
category_applied_input_types = {key: ["text"] for key in pattern_ids}
|
||||
user_message = f"Security concerns detected in the code. {scan_result.recommended_treatment.name}: {', '.join([issue.description for issue in scan_result.issues_found])}"
|
||||
metadata = {"violation_type": ",".join([issue.pattern_id for issue in scan_result.issues_found])}
|
||||
|
||||
return ModerationObjectResults(
|
||||
flagged=flagged,
|
||||
categories=categories,
|
||||
category_scores=category_scores,
|
||||
category_applied_input_types=category_applied_input_types,
|
||||
user_message=user_message,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
|
||||
if model is None:
|
||||
raise ValueError("Code scanner moderation requires a model identifier.")
|
||||
|
||||
inputs = input if isinstance(input, list) else [input]
|
||||
results = []
|
||||
|
||||
from codeshield.cs import CodeShield
|
||||
|
||||
for text_input in inputs:
|
||||
log.info(f"Running CodeScannerShield moderation on input: {text_input[:100]}...")
|
||||
try:
|
||||
scan_result = await CodeShield.scan_code(text_input)
|
||||
moderation_result = self.get_moderation_object_results(scan_result)
|
||||
except Exception as e:
|
||||
log.error(f"CodeShield.scan_code failed: {e}")
|
||||
# create safe fallback response on scanner failure to avoid blocking legitimate requests
|
||||
moderation_result = ModerationObjectResults(
|
||||
flagged=False,
|
||||
categories={},
|
||||
category_scores={},
|
||||
category_applied_input_types={},
|
||||
user_message=None,
|
||||
metadata={"scanner_error": str(e)},
|
||||
)
|
||||
results.append(moderation_result)
|
||||
|
||||
return ModerationObject(id=str(uuid.uuid4()), model=model, results=results)
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeScannerConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {}
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import LlamaGuardConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: LlamaGuardConfig, deps: dict[str, Any]):
|
||||
from .llama_guard import LlamaGuardSafetyImpl
|
||||
|
||||
assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = LlamaGuardSafetyImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LlamaGuardConfig(BaseModel):
|
||||
excluded_categories: list[str] = []
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"excluded_categories": [],
|
||||
}
|
||||
|
|
@ -0,0 +1,492 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from string import Template
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.core.datatypes import Api
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import Role
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
from .config import LlamaGuardConfig
|
||||
|
||||
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
||||
|
||||
SAFE_RESPONSE = "safe"
|
||||
|
||||
CAT_VIOLENT_CRIMES = "Violent Crimes"
|
||||
CAT_NON_VIOLENT_CRIMES = "Non-Violent Crimes"
|
||||
CAT_SEX_CRIMES = "Sex Crimes"
|
||||
CAT_CHILD_EXPLOITATION = "Child Exploitation"
|
||||
CAT_DEFAMATION = "Defamation"
|
||||
CAT_SPECIALIZED_ADVICE = "Specialized Advice"
|
||||
CAT_PRIVACY = "Privacy"
|
||||
CAT_INTELLECTUAL_PROPERTY = "Intellectual Property"
|
||||
CAT_INDISCRIMINATE_WEAPONS = "Indiscriminate Weapons"
|
||||
CAT_HATE = "Hate"
|
||||
CAT_SELF_HARM = "Self-Harm"
|
||||
CAT_SEXUAL_CONTENT = "Sexual Content"
|
||||
CAT_ELECTIONS = "Elections"
|
||||
CAT_CODE_INTERPRETER_ABUSE = "Code Interpreter Abuse"
|
||||
|
||||
|
||||
SAFETY_CATEGORIES_TO_CODE_MAP = {
|
||||
CAT_VIOLENT_CRIMES: "S1",
|
||||
CAT_NON_VIOLENT_CRIMES: "S2",
|
||||
CAT_SEX_CRIMES: "S3",
|
||||
CAT_CHILD_EXPLOITATION: "S4",
|
||||
CAT_DEFAMATION: "S5",
|
||||
CAT_SPECIALIZED_ADVICE: "S6",
|
||||
CAT_PRIVACY: "S7",
|
||||
CAT_INTELLECTUAL_PROPERTY: "S8",
|
||||
CAT_INDISCRIMINATE_WEAPONS: "S9",
|
||||
CAT_HATE: "S10",
|
||||
CAT_SELF_HARM: "S11",
|
||||
CAT_SEXUAL_CONTENT: "S12",
|
||||
CAT_ELECTIONS: "S13",
|
||||
CAT_CODE_INTERPRETER_ABUSE: "S14",
|
||||
}
|
||||
SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
|
||||
|
||||
DEFAULT_LG_V3_SAFETY_CATEGORIES = [
|
||||
CAT_VIOLENT_CRIMES,
|
||||
CAT_NON_VIOLENT_CRIMES,
|
||||
CAT_SEX_CRIMES,
|
||||
CAT_CHILD_EXPLOITATION,
|
||||
CAT_DEFAMATION,
|
||||
CAT_SPECIALIZED_ADVICE,
|
||||
CAT_PRIVACY,
|
||||
CAT_INTELLECTUAL_PROPERTY,
|
||||
CAT_INDISCRIMINATE_WEAPONS,
|
||||
CAT_HATE,
|
||||
CAT_SELF_HARM,
|
||||
CAT_SEXUAL_CONTENT,
|
||||
CAT_ELECTIONS,
|
||||
]
|
||||
|
||||
# accept both CoreModelId and huggingface repo id
|
||||
LLAMA_GUARD_MODEL_IDS = {
|
||||
CoreModelId.llama_guard_3_8b.value: "meta-llama/Llama-Guard-3-8B",
|
||||
"meta-llama/Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
|
||||
CoreModelId.llama_guard_3_1b.value: "meta-llama/Llama-Guard-3-1B",
|
||||
"meta-llama/Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
|
||||
CoreModelId.llama_guard_3_11b_vision.value: "meta-llama/Llama-Guard-3-11B-Vision",
|
||||
"meta-llama/Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
|
||||
CoreModelId.llama_guard_4_12b.value: "meta-llama/Llama-Guard-4-12B",
|
||||
"meta-llama/Llama-Guard-4-12B": "meta-llama/Llama-Guard-4-12B",
|
||||
}
|
||||
|
||||
MODEL_TO_SAFETY_CATEGORIES_MAP = {
|
||||
"meta-llama/Llama-Guard-3-8B": DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE],
|
||||
"meta-llama/Llama-Guard-3-1B": DEFAULT_LG_V3_SAFETY_CATEGORIES,
|
||||
"meta-llama/Llama-Guard-3-11B-Vision": DEFAULT_LG_V3_SAFETY_CATEGORIES,
|
||||
# Llama Guard 4 uses the same categories as Llama Guard 3
|
||||
# source: https://github.com/meta-llama/PurpleLlama/blob/main/Llama-Guard4/12B/MODEL_CARD.md
|
||||
"meta-llama/Llama-Guard-4-12B": DEFAULT_LG_V3_SAFETY_CATEGORIES,
|
||||
}
|
||||
|
||||
|
||||
PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
|
||||
|
||||
SAFETY_CATEGORIES = """
|
||||
|
||||
<BEGIN UNSAFE CONTENT CATEGORIES>
|
||||
$categories
|
||||
<END UNSAFE CONTENT CATEGORIES>
|
||||
"""
|
||||
|
||||
PROMPT_CONVERSATION = """
|
||||
<BEGIN CONVERSATION>
|
||||
|
||||
$conversations
|
||||
|
||||
<END CONVERSATION>
|
||||
"""
|
||||
|
||||
|
||||
PROMPT_INSTRUCTIONS = """
|
||||
Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation:
|
||||
- First line must read 'safe' or 'unsafe'.
|
||||
- If unsafe, a second line must include a comma-separated list of violated categories."""
|
||||
|
||||
|
||||
PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}")
|
||||
|
||||
logger = get_logger(name=__name__, category="safety")
|
||||
|
||||
|
||||
class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: LlamaGuardConfig, deps) -> None:
|
||||
self.config = config
|
||||
self.inference_api = deps[Api.inference]
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
model_id = shield.provider_resource_id
|
||||
if not model_id:
|
||||
raise ValueError("Llama Guard shield must have a model id")
|
||||
|
||||
async def unregister_shield(self, identifier: str) -> None:
|
||||
# LlamaGuard doesn't need to do anything special for unregistration
|
||||
# The routing table handles the removal from the registry
|
||||
pass
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
params: dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
if not shield:
|
||||
raise ValueError(f"Unknown shield {shield_id}")
|
||||
|
||||
messages = messages.copy()
|
||||
# some shields like llama-guard require the first message to be a user message
|
||||
# since this might be a tool call, first role might not be user
|
||||
if len(messages) > 0 and messages[0].role != "user":
|
||||
messages[0] = OpenAIUserMessageParam(content=messages[0].content)
|
||||
|
||||
# Use the inference API's model resolution instead of hardcoded mappings
|
||||
# This allows the shield to work with any registered model
|
||||
model_id = shield.provider_resource_id
|
||||
|
||||
# Determine safety categories based on the model type
|
||||
# For known Llama Guard models, use specific categories
|
||||
if model_id in LLAMA_GUARD_MODEL_IDS:
|
||||
# Use the mapped model for categories but the original model_id for inference
|
||||
mapped_model = LLAMA_GUARD_MODEL_IDS[model_id]
|
||||
safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
|
||||
else:
|
||||
# For unknown models, use default Llama Guard 3 8B categories
|
||||
safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
|
||||
|
||||
impl = LlamaGuardShield(
|
||||
model=model_id,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=self.config.excluded_categories,
|
||||
safety_categories=safety_categories,
|
||||
)
|
||||
|
||||
return await impl.run(messages)
|
||||
|
||||
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
|
||||
if model is None:
|
||||
raise ValueError("Llama Guard moderation requires a model identifier.")
|
||||
|
||||
if isinstance(input, list):
|
||||
messages = input.copy()
|
||||
else:
|
||||
messages = [input]
|
||||
|
||||
# convert to user messages format with role
|
||||
messages = [OpenAIUserMessageParam(content=m) for m in messages]
|
||||
|
||||
# Determine safety categories based on the model type
|
||||
# For known Llama Guard models, use specific categories
|
||||
if model in LLAMA_GUARD_MODEL_IDS:
|
||||
# Use the mapped model for categories but the original model_id for inference
|
||||
mapped_model = LLAMA_GUARD_MODEL_IDS[model]
|
||||
safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
|
||||
else:
|
||||
# For unknown models, use default Llama Guard 3 8B categories
|
||||
safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
|
||||
|
||||
impl = LlamaGuardShield(
|
||||
model=model,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=self.config.excluded_categories,
|
||||
safety_categories=safety_categories,
|
||||
)
|
||||
|
||||
return await impl.run_moderation(messages)
|
||||
|
||||
|
||||
class LlamaGuardShield:
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
inference_api: Inference,
|
||||
excluded_categories: list[str] | None = None,
|
||||
safety_categories: list[str] | None = None,
|
||||
):
|
||||
if excluded_categories is None:
|
||||
excluded_categories = []
|
||||
if safety_categories is None:
|
||||
safety_categories = []
|
||||
|
||||
assert len(excluded_categories) == 0 or all(
|
||||
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
|
||||
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
|
||||
|
||||
self.model = model
|
||||
self.inference_api = inference_api
|
||||
self.excluded_categories = excluded_categories
|
||||
self.safety_categories = safety_categories
|
||||
|
||||
def check_unsafe_response(self, response: str) -> str | None:
|
||||
match = re.match(r"^unsafe\n(.*)$", response)
|
||||
if match:
|
||||
# extracts the unsafe code
|
||||
extracted = match.group(1)
|
||||
return extracted
|
||||
|
||||
return None
|
||||
|
||||
def get_safety_categories(self) -> list[str]:
|
||||
excluded_categories = self.excluded_categories
|
||||
if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()):
|
||||
excluded_categories = []
|
||||
|
||||
final_categories = []
|
||||
|
||||
all_categories = self.safety_categories
|
||||
for cat in all_categories:
|
||||
cat_code = SAFETY_CATEGORIES_TO_CODE_MAP[cat]
|
||||
if cat_code in excluded_categories:
|
||||
continue
|
||||
final_categories.append(f"{cat_code}: {cat}.")
|
||||
|
||||
return final_categories
|
||||
|
||||
def validate_messages(self, messages: list[OpenAIMessageParam]) -> list[OpenAIMessageParam]:
|
||||
if len(messages) == 0:
|
||||
raise ValueError("Messages must not be empty")
|
||||
if messages[0].role != Role.user.value:
|
||||
raise ValueError("Messages must start with user")
|
||||
|
||||
if len(messages) >= 2 and (messages[0].role == Role.user.value and messages[1].role == Role.user.value):
|
||||
messages = messages[1:]
|
||||
|
||||
return messages
|
||||
|
||||
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
|
||||
messages = self.validate_messages(messages)
|
||||
|
||||
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
|
||||
shield_input_message = self.build_vision_shield_input(messages)
|
||||
else:
|
||||
shield_input_message = self.build_text_shield_input(messages)
|
||||
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=self.model,
|
||||
messages=[shield_input_message],
|
||||
stream=False,
|
||||
temperature=0.0, # default is 1, which is too high for safety
|
||||
)
|
||||
response = await self.inference_api.openai_chat_completion(params)
|
||||
content = response.choices[0].message.content
|
||||
content = content.strip()
|
||||
return self.get_shield_response(content)
|
||||
|
||||
def build_text_shield_input(self, messages: list[OpenAIMessageParam]) -> OpenAIUserMessageParam:
|
||||
return OpenAIUserMessageParam(content=self.build_prompt(messages))
|
||||
|
||||
def build_vision_shield_input(self, messages: list[OpenAIMessageParam]) -> OpenAIUserMessageParam:
|
||||
conversation = []
|
||||
most_recent_img = None
|
||||
|
||||
for m in messages[::-1]:
|
||||
if isinstance(m.content, str) or isinstance(m.content, TextContentItem):
|
||||
conversation.append(m)
|
||||
elif isinstance(m.content, ImageContentItem):
|
||||
if most_recent_img is None and m.role == Role.user.value:
|
||||
most_recent_img = m.content
|
||||
conversation.append(m)
|
||||
elif isinstance(m.content, list):
|
||||
content = []
|
||||
for c in m.content:
|
||||
if isinstance(c, str) or isinstance(c, TextContentItem):
|
||||
content.append(c)
|
||||
elif isinstance(c, ImageContentItem):
|
||||
if most_recent_img is None and m.role == Role.user.value:
|
||||
most_recent_img = c
|
||||
content.append(c)
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {c}")
|
||||
|
||||
conversation.append(OpenAIUserMessageParam(content=content))
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {m.content}")
|
||||
|
||||
prompt = []
|
||||
if most_recent_img is not None:
|
||||
prompt.append(most_recent_img)
|
||||
prompt.append(self.build_prompt(conversation[::-1]))
|
||||
|
||||
return OpenAIUserMessageParam(content=prompt)
|
||||
|
||||
def build_prompt(self, messages: list[OpenAIMessageParam]) -> str:
|
||||
categories = self.get_safety_categories()
|
||||
categories_str = "\n".join(categories)
|
||||
conversations_str = "\n\n".join(
|
||||
[f"{m.role.capitalize()}: {interleaved_content_as_str(m.content)}" for m in messages]
|
||||
)
|
||||
return PROMPT_TEMPLATE.substitute(
|
||||
agent_type=messages[-1].role.capitalize(),
|
||||
categories=categories_str,
|
||||
conversations=conversations_str,
|
||||
)
|
||||
|
||||
def get_shield_response(self, response: str) -> RunShieldResponse:
|
||||
response = response.strip()
|
||||
if response == SAFE_RESPONSE:
|
||||
return RunShieldResponse(violation=None)
|
||||
|
||||
unsafe_code = self.check_unsafe_response(response)
|
||||
if unsafe_code:
|
||||
unsafe_code_list = unsafe_code.split(",")
|
||||
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
|
||||
return RunShieldResponse(violation=None)
|
||||
|
||||
return RunShieldResponse(
|
||||
violation=SafetyViolation(
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
user_message=CANNED_RESPONSE_TEXT,
|
||||
metadata={"violation_type": unsafe_code},
|
||||
),
|
||||
)
|
||||
|
||||
raise ValueError(f"Unexpected response: {response}")
|
||||
|
||||
async def run_moderation(self, messages: list[OpenAIMessageParam]) -> ModerationObject:
|
||||
if not messages:
|
||||
return self.create_moderation_object(self.model)
|
||||
|
||||
# TODO: Add Image based support for OpenAI Moderations
|
||||
shield_input_message = self.build_text_shield_input(messages)
|
||||
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=self.model,
|
||||
messages=[shield_input_message],
|
||||
stream=False,
|
||||
temperature=0.0, # default is 1, which is too high for safety
|
||||
)
|
||||
response = await self.inference_api.openai_chat_completion(params)
|
||||
content = response.choices[0].message.content
|
||||
content = content.strip()
|
||||
return self.get_moderation_object(content)
|
||||
|
||||
def create_moderation_object(self, model: str, unsafe_code: str | None = None) -> ModerationObject:
|
||||
"""Create a ModerationObject for either safe or unsafe content.
|
||||
|
||||
Args:
|
||||
model: The model name
|
||||
unsafe_code: Optional comma-separated list of safety codes. If None, creates safe object.
|
||||
|
||||
Returns:
|
||||
ModerationObject with appropriate configuration
|
||||
"""
|
||||
# Set default values for safe case
|
||||
categories = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), False)
|
||||
category_scores = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), 1.0)
|
||||
category_applied_input_types = {key: [] for key in SAFETY_CATEGORIES_TO_CODE_MAP.keys()}
|
||||
flagged = False
|
||||
user_message = None
|
||||
metadata = {}
|
||||
|
||||
# Handle unsafe case
|
||||
if unsafe_code:
|
||||
unsafe_code_list = [code.strip() for code in unsafe_code.split(",")]
|
||||
invalid_codes = [code for code in unsafe_code_list if code not in SAFETY_CODE_TO_CATEGORIES_MAP]
|
||||
if invalid_codes:
|
||||
logger.warning(f"Invalid safety codes returned: {invalid_codes}")
|
||||
# just returning safe object, as we don't know what the invalid codes can map to
|
||||
return ModerationObject(
|
||||
id=f"modr-{uuid.uuid4()}",
|
||||
model=model,
|
||||
results=[
|
||||
ModerationObjectResults(
|
||||
flagged=flagged,
|
||||
categories=categories,
|
||||
category_applied_input_types=category_applied_input_types,
|
||||
category_scores=category_scores,
|
||||
user_message=user_message,
|
||||
metadata=metadata,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
llama_guard_category = [SAFETY_CODE_TO_CATEGORIES_MAP[code] for code in unsafe_code_list]
|
||||
|
||||
# Update categories for unsafe content
|
||||
categories = {k: k in llama_guard_category for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()}
|
||||
category_scores = {
|
||||
k: 1.0 if k in llama_guard_category else 0.0 for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()
|
||||
}
|
||||
category_applied_input_types = {
|
||||
k: ["text"] if k in llama_guard_category else [] for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()
|
||||
}
|
||||
flagged = True
|
||||
user_message = CANNED_RESPONSE_TEXT
|
||||
metadata = {"violation_type": unsafe_code_list}
|
||||
|
||||
return ModerationObject(
|
||||
id=f"modr-{uuid.uuid4()}",
|
||||
model=model,
|
||||
results=[
|
||||
ModerationObjectResults(
|
||||
flagged=flagged,
|
||||
categories=categories,
|
||||
category_applied_input_types=category_applied_input_types,
|
||||
category_scores=category_scores,
|
||||
user_message=user_message,
|
||||
metadata=metadata,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def is_content_safe(self, response: str, unsafe_code: str | None = None) -> bool:
|
||||
"""Check if content is safe based on response and unsafe code."""
|
||||
if response.strip().lower().startswith(SAFE_RESPONSE):
|
||||
return True
|
||||
|
||||
if unsafe_code:
|
||||
unsafe_code_list = unsafe_code.split(",")
|
||||
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_moderation_object(self, response: str) -> ModerationObject:
|
||||
response = response.strip()
|
||||
if self.is_content_safe(response):
|
||||
return self.create_moderation_object(self.model)
|
||||
unsafe_code = self.check_unsafe_response(response)
|
||||
if not unsafe_code:
|
||||
raise ValueError(f"Unexpected response: {response}")
|
||||
|
||||
if self.is_content_safe(response, unsafe_code):
|
||||
return self.create_moderation_object(self.model)
|
||||
else:
|
||||
return self.create_moderation_object(self.model, unsafe_code)
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import PromptGuardConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: PromptGuardConfig, deps: dict[str, Any]):
|
||||
from .prompt_guard import PromptGuardSafetyImpl
|
||||
|
||||
impl = PromptGuardSafetyImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class PromptGuardType(Enum):
|
||||
injection = "injection"
|
||||
jailbreak = "jailbreak"
|
||||
|
||||
|
||||
class PromptGuardConfig(BaseModel):
|
||||
guard_type: str = PromptGuardType.injection.value
|
||||
|
||||
@classmethod
|
||||
@field_validator("guard_type")
|
||||
def validate_guard_type(cls, v):
|
||||
if v not in [t.value for t in PromptGuardType]:
|
||||
raise ValueError(f"Unknown prompt guard type: {v}")
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"guard_type": "injection",
|
||||
}
|
||||
|
|
@ -0,0 +1,131 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
from llama_stack.apis.inference import OpenAIMessageParam
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
ShieldStore,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.safety.safety import ModerationObject
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
|
||||
from .config import PromptGuardConfig, PromptGuardType
|
||||
|
||||
log = get_logger(name=__name__, category="safety")
|
||||
|
||||
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
|
||||
|
||||
|
||||
class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||
shield_store: ShieldStore
|
||||
|
||||
def __init__(self, config: PromptGuardConfig, _deps) -> None:
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
|
||||
self.shield = PromptGuardShield(model_dir, self.config)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
if shield.provider_resource_id != PROMPT_GUARD_MODEL:
|
||||
raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ")
|
||||
|
||||
async def unregister_shield(self, identifier: str) -> None:
|
||||
pass
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
params: dict[str, Any],
|
||||
) -> RunShieldResponse:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
if not shield:
|
||||
raise ValueError(f"Unknown shield {shield_id}")
|
||||
|
||||
return await self.shield.run(messages)
|
||||
|
||||
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
|
||||
raise NotImplementedError("run_moderation is not implemented for Prompt Guard")
|
||||
|
||||
|
||||
class PromptGuardShield:
|
||||
def __init__(
|
||||
self,
|
||||
model_dir: str,
|
||||
config: PromptGuardConfig,
|
||||
threshold: float = 0.9,
|
||||
temperature: float = 1.0,
|
||||
):
|
||||
assert model_dir is not None, "Must provide a model directory for prompt injection shield"
|
||||
if temperature <= 0:
|
||||
raise ValueError("Temperature must be greater than 0")
|
||||
|
||||
self.config = config
|
||||
self.temperature = temperature
|
||||
self.threshold = threshold
|
||||
|
||||
self.device = "cpu"
|
||||
if torch.cuda.is_available():
|
||||
self.device = "cuda"
|
||||
|
||||
# load model and tokenizer
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)
|
||||
|
||||
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
|
||||
message = messages[-1]
|
||||
text = interleaved_content_as_str(message.content)
|
||||
|
||||
# run model on messages and return response
|
||||
inputs = self.tokenizer(text, return_tensors="pt")
|
||||
inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
logits = outputs[0]
|
||||
probabilities = torch.softmax(logits / self.temperature, dim=-1)
|
||||
score_embedded = probabilities[0, 1].item()
|
||||
score_malicious = probabilities[0, 2].item()
|
||||
log.info(
|
||||
f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
|
||||
)
|
||||
|
||||
violation = None
|
||||
if self.config.guard_type == PromptGuardType.injection.value and (
|
||||
score_embedded + score_malicious > self.threshold
|
||||
):
|
||||
violation = SafetyViolation(
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
user_message="Sorry, I cannot do this.",
|
||||
metadata={
|
||||
"violation_type": f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
|
||||
},
|
||||
)
|
||||
elif self.config.guard_type == PromptGuardType.jailbreak.value and score_malicious > self.threshold:
|
||||
violation = SafetyViolation(
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
user_message="Sorry, I cannot do this.",
|
||||
metadata={
|
||||
"violation_type": f"prompt_injection:malicious={score_malicious}",
|
||||
},
|
||||
)
|
||||
|
||||
return RunShieldResponse(violation=violation)
|
||||
5
src/llama_stack/providers/inline/scoring/__init__.py
Normal file
5
src/llama_stack/providers/inline/scoring/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
25
src/llama_stack/providers/inline/scoring/basic/__init__.py
Normal file
25
src/llama_stack/providers/inline/scoring/basic/__init__.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import Api
|
||||
|
||||
from .config import BasicScoringConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: BasicScoringConfig,
|
||||
deps: dict[Api, Any],
|
||||
):
|
||||
from .scoring import BasicScoringImpl
|
||||
|
||||
impl = BasicScoringImpl(
|
||||
config,
|
||||
deps[Api.datasetio],
|
||||
deps[Api.datasets],
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
14
src/llama_stack/providers/inline/scoring/basic/config.py
Normal file
14
src/llama_stack/providers/inline/scoring/basic/config.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BasicScoringConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {}
|
||||
126
src/llama_stack/providers/inline/scoring/basic/scoring.py
Normal file
126
src/llama_stack/providers/inline/scoring/basic/scoring.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.scoring import (
|
||||
ScoreBatchResponse,
|
||||
ScoreResponse,
|
||||
Scoring,
|
||||
ScoringResult,
|
||||
)
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
|
||||
from llama_stack.core.datatypes import Api
|
||||
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||
from llama_stack.providers.utils.common.data_schema_validator import (
|
||||
get_valid_schemas,
|
||||
validate_dataset_schema,
|
||||
)
|
||||
|
||||
from .config import BasicScoringConfig
|
||||
from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn
|
||||
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
|
||||
from .scoring_fn.ifeval_scoring_fn import IfEvalScoringFn
|
||||
from .scoring_fn.regex_parser_math_response_scoring_fn import (
|
||||
RegexParserMathResponseScoringFn,
|
||||
)
|
||||
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
|
||||
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
|
||||
|
||||
FIXED_FNS = [
|
||||
EqualityScoringFn,
|
||||
SubsetOfScoringFn,
|
||||
RegexParserScoringFn,
|
||||
RegexParserMathResponseScoringFn,
|
||||
IfEvalScoringFn,
|
||||
DocVQAScoringFn,
|
||||
]
|
||||
|
||||
|
||||
class BasicScoringImpl(
|
||||
Scoring,
|
||||
ScoringFunctionsProtocolPrivate,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
config: BasicScoringConfig,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets_api: Datasets,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets_api
|
||||
self.scoring_fn_id_impls = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
for fn in FIXED_FNS:
|
||||
impl = fn()
|
||||
for fn_defs in impl.get_supported_scoring_fn_defs():
|
||||
self.scoring_fn_id_impls[fn_defs.identifier] = impl
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def list_scoring_functions(self) -> list[ScoringFn]:
|
||||
scoring_fn_defs_list = [
|
||||
fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
|
||||
]
|
||||
|
||||
for f in scoring_fn_defs_list:
|
||||
assert f.identifier.startswith("basic"), "All basic scoring fn must have identifier prefixed with 'basic'! "
|
||||
|
||||
return scoring_fn_defs_list
|
||||
|
||||
async def register_scoring_function(self, function_def: ScoringFn) -> None:
|
||||
raise NotImplementedError("Register scoring function not implemented yet")
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: dict[str, ScoringFnParams | None] = None,
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
|
||||
|
||||
all_rows = await self.datasetio_api.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
limit=-1,
|
||||
)
|
||||
res = await self.score(
|
||||
input_rows=all_rows.data,
|
||||
scoring_functions=scoring_functions,
|
||||
)
|
||||
if save_results_dataset:
|
||||
# TODO: persist and register dataset on to server for reading
|
||||
# self.datasets_api.register_dataset()
|
||||
raise NotImplementedError("Save results dataset not implemented yet")
|
||||
|
||||
return ScoreBatchResponse(
|
||||
results=res.results,
|
||||
)
|
||||
|
||||
async def score(
|
||||
self,
|
||||
input_rows: list[dict[str, Any]],
|
||||
scoring_functions: dict[str, ScoringFnParams | None] = None,
|
||||
) -> ScoreResponse:
|
||||
res = {}
|
||||
for scoring_fn_id in scoring_functions.keys():
|
||||
if scoring_fn_id not in self.scoring_fn_id_impls:
|
||||
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
|
||||
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
|
||||
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
|
||||
score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
|
||||
agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
|
||||
res[scoring_fn_id] = ScoringResult(
|
||||
score_rows=score_results,
|
||||
aggregated_results=agg_results,
|
||||
)
|
||||
|
||||
return ScoreResponse(
|
||||
results=res,
|
||||
)
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -0,0 +1,240 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from .fn_defs.docvqa import docvqa
|
||||
|
||||
CONTRACTIONS = {
|
||||
"aint": "ain't",
|
||||
"arent": "aren't",
|
||||
"cant": "can't",
|
||||
"couldve": "could've",
|
||||
"couldnt": "couldn't",
|
||||
"couldn'tve": "couldn't've",
|
||||
"couldnt've": "couldn't've",
|
||||
"didnt": "didn't",
|
||||
"doesnt": "doesn't",
|
||||
"dont": "don't",
|
||||
"hadnt": "hadn't",
|
||||
"hadnt've": "hadn't've",
|
||||
"hadn'tve": "hadn't've",
|
||||
"hasnt": "hasn't",
|
||||
"havent": "haven't",
|
||||
"hed": "he'd",
|
||||
"hed've": "he'd've",
|
||||
"he'dve": "he'd've",
|
||||
"hes": "he's",
|
||||
"howd": "how'd",
|
||||
"howll": "how'll",
|
||||
"hows": "how's",
|
||||
"Id've": "I'd've",
|
||||
"I'dve": "I'd've",
|
||||
"Im": "I'm",
|
||||
"Ive": "I've",
|
||||
"isnt": "isn't",
|
||||
"itd": "it'd",
|
||||
"itd've": "it'd've",
|
||||
"it'dve": "it'd've",
|
||||
"itll": "it'll",
|
||||
"let's": "let's",
|
||||
"maam": "ma'am",
|
||||
"mightnt": "mightn't",
|
||||
"mightnt've": "mightn't've",
|
||||
"mightn'tve": "mightn't've",
|
||||
"mightve": "might've",
|
||||
"mustnt": "mustn't",
|
||||
"mustve": "must've",
|
||||
"neednt": "needn't",
|
||||
"notve": "not've",
|
||||
"oclock": "o'clock",
|
||||
"oughtnt": "oughtn't",
|
||||
"ow's'at": "'ow's'at",
|
||||
"'ows'at": "'ow's'at",
|
||||
"'ow'sat": "'ow's'at",
|
||||
"shant": "shan't",
|
||||
"shed've": "she'd've",
|
||||
"she'dve": "she'd've",
|
||||
"she's": "she's",
|
||||
"shouldve": "should've",
|
||||
"shouldnt": "shouldn't",
|
||||
"shouldnt've": "shouldn't've",
|
||||
"shouldn'tve": "shouldn't've",
|
||||
"somebody'd": "somebodyd",
|
||||
"somebodyd've": "somebody'd've",
|
||||
"somebody'dve": "somebody'd've",
|
||||
"somebodyll": "somebody'll",
|
||||
"somebodys": "somebody's",
|
||||
"someoned": "someone'd",
|
||||
"someoned've": "someone'd've",
|
||||
"someone'dve": "someone'd've",
|
||||
"someonell": "someone'll",
|
||||
"someones": "someone's",
|
||||
"somethingd": "something'd",
|
||||
"somethingd've": "something'd've",
|
||||
"something'dve": "something'd've",
|
||||
"somethingll": "something'll",
|
||||
"thats": "that's",
|
||||
"thered": "there'd",
|
||||
"thered've": "there'd've",
|
||||
"there'dve": "there'd've",
|
||||
"therere": "there're",
|
||||
"theres": "there's",
|
||||
"theyd": "they'd",
|
||||
"theyd've": "they'd've",
|
||||
"they'dve": "they'd've",
|
||||
"theyll": "they'll",
|
||||
"theyre": "they're",
|
||||
"theyve": "they've",
|
||||
"twas": "'twas",
|
||||
"wasnt": "wasn't",
|
||||
"wed've": "we'd've",
|
||||
"we'dve": "we'd've",
|
||||
"weve": "we've",
|
||||
"werent": "weren't",
|
||||
"whatll": "what'll",
|
||||
"whatre": "what're",
|
||||
"whats": "what's",
|
||||
"whatve": "what've",
|
||||
"whens": "when's",
|
||||
"whered": "where'd",
|
||||
"wheres": "where's",
|
||||
"whereve": "where've",
|
||||
"whod": "who'd",
|
||||
"whod've": "who'd've",
|
||||
"who'dve": "who'd've",
|
||||
"wholl": "who'll",
|
||||
"whos": "who's",
|
||||
"whove": "who've",
|
||||
"whyll": "why'll",
|
||||
"whyre": "why're",
|
||||
"whys": "why's",
|
||||
"wont": "won't",
|
||||
"wouldve": "would've",
|
||||
"wouldnt": "wouldn't",
|
||||
"wouldnt've": "wouldn't've",
|
||||
"wouldn'tve": "wouldn't've",
|
||||
"yall": "y'all",
|
||||
"yall'll": "y'all'll",
|
||||
"y'allll": "y'all'll",
|
||||
"yall'd've": "y'all'd've",
|
||||
"y'alld've": "y'all'd've",
|
||||
"y'all'dve": "y'all'd've",
|
||||
"youd": "you'd",
|
||||
"youd've": "you'd've",
|
||||
"you'dve": "you'd've",
|
||||
"youll": "you'll",
|
||||
"youre": "you're",
|
||||
"youve": "you've",
|
||||
"1st": "first",
|
||||
"2nd": "second",
|
||||
"3rd": "third",
|
||||
}
|
||||
NUMBERS = {
|
||||
"none": "0",
|
||||
"zero": "0",
|
||||
"one": "1",
|
||||
"two": "2",
|
||||
"three": "3",
|
||||
"four": "4",
|
||||
"five": "5",
|
||||
"six": "6",
|
||||
"seven": "7",
|
||||
"eight": "8",
|
||||
"nine": "9",
|
||||
"ten": "10",
|
||||
}
|
||||
ARTICLES = [
|
||||
"a",
|
||||
"an",
|
||||
"the",
|
||||
"to",
|
||||
"in",
|
||||
"from",
|
||||
"by",
|
||||
] # Contains a bit more than just articles, but we want to get rid of these elements influencing the accuracy
|
||||
PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
|
||||
COMMA_STRIP = re.compile(r"(\d)(\,)(\d)")
|
||||
PUNCTUATION = [
|
||||
";",
|
||||
r"/",
|
||||
"[",
|
||||
"]",
|
||||
'"',
|
||||
"{",
|
||||
"}",
|
||||
"(",
|
||||
")",
|
||||
"=",
|
||||
"+",
|
||||
"\\",
|
||||
"_",
|
||||
"-",
|
||||
">",
|
||||
"<",
|
||||
"@",
|
||||
"`",
|
||||
",",
|
||||
"?",
|
||||
"!",
|
||||
]
|
||||
|
||||
|
||||
def normalize_answer(s: str) -> str:
|
||||
# process punctuation
|
||||
for p in PUNCTUATION:
|
||||
if (p + " " in s or " " + p in s) or (re.search(COMMA_STRIP, s) is not None):
|
||||
s = s.replace(p, "")
|
||||
else:
|
||||
s = s.replace(p, " ")
|
||||
s = PERIOD_STRIP.sub("", s, re.UNICODE)
|
||||
|
||||
# process digits and articles
|
||||
temp_text = s.lower().split()
|
||||
out_text = []
|
||||
for word in temp_text:
|
||||
word = NUMBERS.setdefault(word, word)
|
||||
if word not in ARTICLES:
|
||||
out_text.append(word)
|
||||
|
||||
# standardize contractions
|
||||
for word_id, word in enumerate(out_text):
|
||||
if word in CONTRACTIONS:
|
||||
out_text[word_id] = CONTRACTIONS[word]
|
||||
return " ".join(out_text)
|
||||
|
||||
|
||||
class DocVQAScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
docvqa basically matches the generated answer against several allowed
|
||||
choices, but we need to normalize the answer to avoid penalizing
|
||||
trivial differences
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
docvqa.identifier: docvqa,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = "docvqa",
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
expected_answers = json.loads(input_row["expected_answer"])
|
||||
generated_answer = input_row["generated_answer"]
|
||||
score = 1.0 if normalize_answer(generated_answer) in [normalize_answer(s) for s in expected_answers] else 0.0
|
||||
return {
|
||||
"score": score,
|
||||
}
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from .fn_defs.equality import equality
|
||||
|
||||
|
||||
class EqualityScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
equality.identifier: equality,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = "equality",
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
assert "expected_answer" in input_row, "Expected answer not found in input row."
|
||||
assert "generated_answer" in input_row, "Generated answer not found in input row."
|
||||
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
score = 1.0 if expected_answer == generated_answer else 0.0
|
||||
return {
|
||||
"score": score,
|
||||
}
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
docvqa = ScoringFn(
|
||||
identifier="basic::docvqa",
|
||||
description="DocVQA Visual Question & Answer scoring function",
|
||||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="docvqa",
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
|
||||
)
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
equality = ScoringFn(
|
||||
identifier="basic::equality",
|
||||
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
|
||||
provider_id="basic",
|
||||
provider_resource_id="equality",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
|
||||
)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
ifeval = ScoringFn(
|
||||
identifier="basic::ifeval",
|
||||
description="Eval intruction follow capacity by checkping how many instructions can be followed in each example",
|
||||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="ifeval",
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.weighted_average],
|
||||
),
|
||||
)
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
RegexParserScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
MATH_ANSWER_REGEXES = [r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"]
|
||||
|
||||
|
||||
regex_parser_math_response = ScoringFn(
|
||||
identifier="basic::regex_parser_math_response",
|
||||
description="For math related benchmarks, extract answer from the generated response and expected_answer and see if they match",
|
||||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="regex-parser-math-response",
|
||||
params=RegexParserScoringFnParams(
|
||||
parsing_regexes=MATH_ANSWER_REGEXES,
|
||||
aggregation_functions=[AggregationFunctionType.accuracy],
|
||||
),
|
||||
)
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
RegexParserScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
MULTILINGUAL_ANSWER_REGEXES = [
|
||||
r"The best answer is ",
|
||||
r"Answer\s*:",
|
||||
r"Answer\s*:", # Korean invisible character
|
||||
r"উত্তর\s*:",
|
||||
r"उत्तर\s*:",
|
||||
r"উত্তরঃ",
|
||||
r"উত্তর\s*:",
|
||||
r"Antwort\s*:",
|
||||
r"답변\s*:",
|
||||
r"정답\s*:",
|
||||
r"답\s*:",
|
||||
r"答案\s*:",
|
||||
r"答案\s*:",
|
||||
r"答\s*:",
|
||||
r"答\s*:",
|
||||
r"答复\s*:",
|
||||
r"答曰\s*:",
|
||||
r"الإجابة:",
|
||||
r"الجواب:",
|
||||
r"إجابة:",
|
||||
r"الإجابة النهائية:",
|
||||
r"الإجابة الصحيحة:",
|
||||
r"الإجابة الصحيحة هي:",
|
||||
r"الإجابة هي:",
|
||||
r"Respuesta\s*:",
|
||||
r"Risposta\s*:",
|
||||
r"答え\s*:",
|
||||
r"答え\s*:",
|
||||
r"回答\s*:",
|
||||
r"回答\s*:",
|
||||
r"解答\s*:",
|
||||
r"Jawaban\s*:",
|
||||
r"Réponse\s*:",
|
||||
r"Resposta\s*:",
|
||||
r"Jibu\s*:",
|
||||
r"Idahun\s*:",
|
||||
r"Ìdáhùn\s*:",
|
||||
r"Idáhùn\s*:",
|
||||
r"Àmọ̀nà\s*:",
|
||||
r"Àdáhùn\s*:",
|
||||
r"Ànúgọ\s*:",
|
||||
r"Àṣàyàn\s*:",
|
||||
]
|
||||
|
||||
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])"
|
||||
|
||||
regex_parser_multiple_choice_answer = ScoringFn(
|
||||
identifier="basic::regex_parser_multiple_choice_answer",
|
||||
description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result",
|
||||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="regex-parser-multiple-choice-answer",
|
||||
params=RegexParserScoringFnParams(
|
||||
parsing_regexes=[MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x) for x in MULTILINGUAL_ANSWER_REGEXES],
|
||||
aggregation_functions=[AggregationFunctionType.accuracy],
|
||||
),
|
||||
)
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
subset_of = ScoringFn(
|
||||
identifier="basic::subset_of",
|
||||
description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
|
||||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="subset-of",
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
|
||||
)
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from .fn_defs.ifeval import (
|
||||
ifeval,
|
||||
)
|
||||
|
||||
|
||||
class IfEvalScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn Instruction-Following Eval (IFEval) benchmark
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
ifeval.identifier: ifeval,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = None,
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST
|
||||
|
||||
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
if scoring_params is not None:
|
||||
fn_def.params = scoring_params
|
||||
|
||||
instruction_list = input_row["instruction_id_list"]
|
||||
generated_answer = input_row["generated_answer"].strip()
|
||||
|
||||
is_following_list = []
|
||||
results = dict(
|
||||
{k + "_correct": 0.0 for k in INSTRUCTION_LIST},
|
||||
**{k + "_total": 0.0 for k in INSTRUCTION_LIST},
|
||||
)
|
||||
|
||||
for index, instruction_id in enumerate(instruction_list):
|
||||
instruction_cls = INSTRUCTION_DICT[instruction_id]
|
||||
instruction = instruction_cls(instruction_id)
|
||||
results[instruction_id + "_total"] += 1.0
|
||||
results[instruction_id.split(":")[0] + "_total"] += 1.0
|
||||
|
||||
clean_input_row = {k: v for k, v in input_row["kwargs"][index].items() if v is not None}
|
||||
print(clean_input_row)
|
||||
instruction.build_description(**clean_input_row)
|
||||
args = instruction.get_instruction_args()
|
||||
if args and "prompt" in args:
|
||||
instruction.build_description(prompt=input_row["prompt"])
|
||||
|
||||
if generated_answer and instruction.check_following(generated_answer):
|
||||
is_following_list.append(True)
|
||||
results[instruction_id + "_correct"] += 1.0
|
||||
results[instruction_id.split(":")[0] + "_correct"] += 1.0
|
||||
else:
|
||||
is_following_list.append(False)
|
||||
|
||||
if len(is_following_list) == 0:
|
||||
return {
|
||||
"score": 0.0,
|
||||
"weight": 0.0,
|
||||
}
|
||||
|
||||
return {
|
||||
"score": float(sum(is_following_list)) / float(len(is_following_list)),
|
||||
"weight": float(len(is_following_list)),
|
||||
}
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from ..utils.math_utils import first_answer, normalize_final_answer, try_evaluate_frac, try_evaluate_latex
|
||||
from .fn_defs.regex_parser_math_response import (
|
||||
regex_parser_math_response,
|
||||
)
|
||||
|
||||
|
||||
class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn for math benchamrks that parses answer from generated response according to context and check match with expected_answer.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
regex_parser_math_response.identifier: regex_parser_math_response,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = None,
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
if scoring_params is not None:
|
||||
fn_def.params = scoring_params
|
||||
|
||||
assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
|
||||
f"RegexParserScoringFnParams not found for {fn_def}."
|
||||
)
|
||||
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
||||
parsing_regexes = fn_def.params.parsing_regexes
|
||||
assert len(parsing_regexes) == 1, (
|
||||
"Only one parsing regex is supported for regex_parser_math_response scoring function."
|
||||
)
|
||||
parsing_regexes = fn_def.params.parsing_regexes[0]
|
||||
|
||||
normalized_generated_answer = normalize_final_answer(
|
||||
first_answer(generated_answer),
|
||||
parsing_regexes,
|
||||
match_first=True,
|
||||
)
|
||||
normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer))
|
||||
|
||||
normalized_expected_answer = normalize_final_answer(expected_answer, r".*")
|
||||
normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer))
|
||||
|
||||
score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0
|
||||
return {
|
||||
"score": score,
|
||||
}
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from .fn_defs.regex_parser_multiple_choice_answer import (
|
||||
regex_parser_multiple_choice_answer,
|
||||
)
|
||||
|
||||
|
||||
class RegexParserScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
regex_parser_multiple_choice_answer.identifier: regex_parser_multiple_choice_answer,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = None,
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
if scoring_params is not None:
|
||||
fn_def.params = scoring_params
|
||||
|
||||
assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
|
||||
f"RegexParserScoringFnParams not found for {fn_def}."
|
||||
)
|
||||
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
||||
# parse answer according to regex
|
||||
parsed_answer = None
|
||||
for regex in fn_def.params.parsing_regexes:
|
||||
match = re.search(regex, generated_answer)
|
||||
if match:
|
||||
parsed_answer = match.group(1)
|
||||
break
|
||||
|
||||
score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0
|
||||
return {
|
||||
"score": score,
|
||||
}
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from .fn_defs.subset_of import subset_of
|
||||
|
||||
|
||||
class SubsetOfScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
subset_of.identifier: subset_of,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = "subset_of",
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
score = 1.0 if expected_answer in generated_answer else 0.0
|
||||
return {
|
||||
"score": score,
|
||||
}
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
3319
src/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py
Normal file
3319
src/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py
Normal file
File diff suppressed because it is too large
Load diff
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue