mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-02 18:34:30 +00:00
Merge branch 'refs/heads/main' into preprocessors
# Conflicts: # llama_stack/distribution/routers/routers.py # llama_stack/templates/ollama/build.yaml # llama_stack/templates/ollama/run-with-safety.yaml # llama_stack/templates/ollama/run.yaml # llama_stack/templates/remote-vllm/build.yaml # llama_stack/templates/remote-vllm/run-with-safety.yaml # llama_stack/templates/remote-vllm/run.yaml # llama_stack/templates/together/build.yaml # llama_stack/templates/together/run-with-safety.yaml # llama_stack/templates/together/run.yaml
This commit is contained in:
commit
6b9f673fdb
313 changed files with 181388 additions and 7064 deletions
|
|
@ -6,18 +6,18 @@
|
|||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import secrets
|
||||
import string
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
AgentToolGroup,
|
||||
|
|
@ -31,7 +31,6 @@ from llama_stack.apis.agents import (
|
|||
AgentTurnResponseStreamChunk,
|
||||
AgentTurnResponseTurnAwaitingInputPayload,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
AgentTurnResponseTurnStartPayload,
|
||||
AgentTurnResumeRequest,
|
||||
Attachment,
|
||||
Document,
|
||||
|
|
@ -79,8 +78,6 @@ from llama_stack.providers.utils.telemetry import tracing
|
|||
from .persistence import AgentPersistence
|
||||
from .safety import SafetyException, ShieldRunnerMixin
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def make_random_string(length: int = 8):
|
||||
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
||||
|
|
@ -186,115 +183,61 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
assert request.stream is True, "Non-streaming not supported"
|
||||
|
||||
session_info = await self.storage.get_session_info(request.session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {request.session_id} not found")
|
||||
|
||||
turns = await self.storage.get_session_turns(request.session_id)
|
||||
messages = await self.get_messages_from_turns(turns)
|
||||
messages.extend(request.messages)
|
||||
|
||||
turn_id = str(uuid.uuid4())
|
||||
span.set_attribute("turn_id", turn_id)
|
||||
start_time = datetime.now().astimezone().isoformat()
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnStartPayload(
|
||||
turn_id=turn_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
steps = []
|
||||
output_message = None
|
||||
async for chunk in self.run(
|
||||
session_id=request.session_id,
|
||||
turn_id=turn_id,
|
||||
input_messages=messages,
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
documents=request.documents,
|
||||
toolgroups_for_turn=request.toolgroups,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
log.info(
|
||||
f"{chunk.role.capitalize()}: {chunk.content}",
|
||||
)
|
||||
output_message = chunk
|
||||
continue
|
||||
|
||||
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
|
||||
event = chunk.event
|
||||
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
|
||||
steps.append(event.payload.step_details)
|
||||
|
||||
async for chunk in self._run_turn(request, turn_id):
|
||||
yield chunk
|
||||
|
||||
assert output_message is not None
|
||||
|
||||
turn = Turn(
|
||||
turn_id=turn_id,
|
||||
session_id=request.session_id,
|
||||
input_messages=request.messages,
|
||||
output_message=output_message,
|
||||
started_at=start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
|
||||
if output_message.tool_calls and request.allow_turn_resume:
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnAwaitingInputPayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnCompletePayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
yield chunk
|
||||
|
||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||
with tracing.span("resume_turn") as span:
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("turn_id", request.turn_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
assert request.stream is True, "Non-streaming not supported"
|
||||
async for chunk in self._run_turn(request):
|
||||
yield chunk
|
||||
|
||||
session_info = await self.storage.get_session_info(request.session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {request.session_id} not found")
|
||||
async def _run_turn(
|
||||
self,
|
||||
request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest],
|
||||
turn_id: Optional[str] = None,
|
||||
) -> AsyncGenerator:
|
||||
assert request.stream is True, "Non-streaming not supported"
|
||||
|
||||
turns = await self.storage.get_session_turns(request.session_id)
|
||||
if len(turns) == 0:
|
||||
raise ValueError("No turns found for session")
|
||||
is_resume = isinstance(request, AgentTurnResumeRequest)
|
||||
session_info = await self.storage.get_session_info(request.session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {request.session_id} not found")
|
||||
|
||||
messages = await self.get_messages_from_turns(turns)
|
||||
messages.extend(request.tool_responses)
|
||||
turns = await self.storage.get_session_turns(request.session_id)
|
||||
if is_resume and len(turns) == 0:
|
||||
raise ValueError("No turns found for session")
|
||||
|
||||
steps = []
|
||||
messages = await self.get_messages_from_turns(turns)
|
||||
if is_resume:
|
||||
if isinstance(request.tool_responses[0], ToolResponseMessage):
|
||||
tool_response_messages = request.tool_responses
|
||||
tool_responses = [
|
||||
ToolResponse(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
|
||||
for x in request.tool_responses
|
||||
]
|
||||
else:
|
||||
tool_response_messages = [
|
||||
ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
|
||||
for x in request.tool_responses
|
||||
]
|
||||
tool_responses = request.tool_responses
|
||||
messages.extend(tool_response_messages)
|
||||
last_turn = turns[-1]
|
||||
last_turn_messages = self.turn_to_messages(last_turn)
|
||||
last_turn_messages = [
|
||||
x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
|
||||
]
|
||||
last_turn_messages.extend(tool_response_messages)
|
||||
|
||||
# TODO: figure out whether we should add the tool responses to the last turn messages
|
||||
last_turn_messages.extend(request.tool_responses)
|
||||
|
||||
# get the steps from the turn id
|
||||
steps = []
|
||||
steps = turns[-1].steps
|
||||
# get steps from the turn
|
||||
steps = last_turn.steps
|
||||
|
||||
# mark tool execution step as complete
|
||||
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
|
||||
|
|
@ -307,14 +250,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
|
||||
turn_id=request.turn_id,
|
||||
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
|
||||
tool_responses=[
|
||||
ToolResponse(
|
||||
call_id=x.call_id,
|
||||
tool_name=x.tool_name,
|
||||
content=x.content,
|
||||
)
|
||||
for x in request.tool_responses
|
||||
],
|
||||
tool_responses=tool_responses,
|
||||
completed_at=now,
|
||||
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
|
||||
)
|
||||
|
|
@ -328,62 +264,67 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
)
|
||||
input_messages = last_turn_messages
|
||||
|
||||
output_message = None
|
||||
async for chunk in self.run(
|
||||
session_id=request.session_id,
|
||||
turn_id=request.turn_id,
|
||||
input_messages=messages,
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
output_message = chunk
|
||||
continue
|
||||
turn_id = request.turn_id
|
||||
start_time = last_turn.started_at
|
||||
else:
|
||||
messages.extend(request.messages)
|
||||
start_time = datetime.now().astimezone().isoformat()
|
||||
input_messages = request.messages
|
||||
|
||||
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
|
||||
event = chunk.event
|
||||
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
|
||||
steps.append(event.payload.step_details)
|
||||
output_message = None
|
||||
async for chunk in self.run(
|
||||
session_id=request.session_id,
|
||||
turn_id=turn_id,
|
||||
input_messages=messages,
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
documents=request.documents if not is_resume else None,
|
||||
toolgroups_for_turn=request.toolgroups if not is_resume else None,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
output_message = chunk
|
||||
continue
|
||||
|
||||
yield chunk
|
||||
|
||||
assert output_message is not None
|
||||
|
||||
last_turn_start_time = datetime.now().astimezone().isoformat()
|
||||
if len(turns) > 0:
|
||||
last_turn_start_time = turns[-1].started_at
|
||||
|
||||
turn = Turn(
|
||||
turn_id=request.turn_id,
|
||||
session_id=request.session_id,
|
||||
input_messages=last_turn_messages,
|
||||
output_message=output_message,
|
||||
started_at=last_turn_start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
|
||||
if output_message.tool_calls:
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnAwaitingInputPayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnCompletePayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
|
||||
event = chunk.event
|
||||
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
|
||||
steps.append(event.payload.step_details)
|
||||
|
||||
yield chunk
|
||||
|
||||
assert output_message is not None
|
||||
|
||||
turn = Turn(
|
||||
turn_id=turn_id,
|
||||
session_id=request.session_id,
|
||||
input_messages=input_messages,
|
||||
output_message=output_message,
|
||||
started_at=start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
if output_message.tool_calls:
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnAwaitingInputPayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnCompletePayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
yield chunk
|
||||
|
||||
async def run(
|
||||
self,
|
||||
session_id: str,
|
||||
|
|
@ -533,9 +474,18 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
if documents:
|
||||
await self.handle_documents(session_id, documents, input_messages, tool_defs)
|
||||
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
# if the session has a memory bank id, let the memory tool use it
|
||||
if session_info and session_info.vector_db_id:
|
||||
if RAG_TOOL_GROUP not in toolgroup_args:
|
||||
toolgroup_args[RAG_TOOL_GROUP] = {"vector_db_ids": [session_info.vector_db_id]}
|
||||
else:
|
||||
toolgroup_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id)
|
||||
|
||||
output_attachments = []
|
||||
|
||||
n_iter = 0
|
||||
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
|
||||
|
||||
# Build a map of custom tools to their definitions for faster lookup
|
||||
client_tools = {}
|
||||
for tool in self.agent_config.client_tools:
|
||||
|
|
@ -622,6 +572,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
span.set_attribute("output", output_attr)
|
||||
|
||||
n_iter += 1
|
||||
await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
|
||||
|
||||
stop_reason = stop_reason or StopReason.out_of_tokens
|
||||
|
||||
# If tool calls are parsed successfully,
|
||||
|
|
@ -656,12 +609,15 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
|
||||
if n_iter >= self.agent_config.max_infer_iters:
|
||||
log.info("Done with MAX iterations, exiting.")
|
||||
logcat.info("agents", f"done with MAX iterations ({n_iter}), exiting.")
|
||||
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
|
||||
# Do not continue the tool call loop after this point
|
||||
message.stop_reason = StopReason.end_of_turn
|
||||
yield message
|
||||
break
|
||||
|
||||
if stop_reason == StopReason.out_of_tokens:
|
||||
log.info("Out of token budget, exiting.")
|
||||
logcat.info("agents", "out of token budget, exiting.")
|
||||
yield message
|
||||
break
|
||||
|
||||
|
|
@ -675,10 +631,16 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
message.content = [message.content] + output_attachments
|
||||
yield message
|
||||
else:
|
||||
log.info(f"Partial message: {str(message)}")
|
||||
logcat.debug(
|
||||
"agents",
|
||||
f"completion message with EOM (iter: {n_iter}): {str(message)}",
|
||||
)
|
||||
input_messages = input_messages + [message]
|
||||
else:
|
||||
log.info(f"{str(message)}")
|
||||
logcat.debug(
|
||||
"agents",
|
||||
f"completion message (iter: {n_iter}) from the model: {str(message)}",
|
||||
)
|
||||
# 1. Start the tool execution step and progress
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
|
|
@ -706,6 +668,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
# If tool is a client tool, yield CompletionMessage and return
|
||||
if tool_call.tool_name in client_tools:
|
||||
# NOTE: mark end_of_message to indicate to client that it may
|
||||
# call the tool and continue the conversation with the tool's response.
|
||||
message.stop_reason = StopReason.end_of_message
|
||||
await self.storage.set_in_progress_tool_call_step(
|
||||
session_id,
|
||||
turn_id,
|
||||
|
|
@ -791,24 +756,16 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
input_messages = input_messages + [message, result_message]
|
||||
|
||||
n_iter += 1
|
||||
|
||||
async def _get_tool_defs(
|
||||
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
||||
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
|
||||
# Determine which tools to include
|
||||
agent_config_toolgroups = set(
|
||||
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
|
||||
for toolgroup in self.agent_config.toolgroups
|
||||
)
|
||||
toolgroups_for_turn_set = (
|
||||
agent_config_toolgroups
|
||||
if toolgroups_for_turn is None
|
||||
else {
|
||||
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
|
||||
for toolgroup in toolgroups_for_turn
|
||||
}
|
||||
)
|
||||
tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or []
|
||||
agent_config_toolgroups = []
|
||||
for toolgroup in tool_groups_to_include:
|
||||
name = toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup
|
||||
if name not in agent_config_toolgroups:
|
||||
agent_config_toolgroups.append(name)
|
||||
|
||||
tool_name_to_def = {}
|
||||
tool_to_group = {}
|
||||
|
|
@ -831,9 +788,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
tool_to_group[tool_def.name] = "__client_tools__"
|
||||
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
||||
if toolgroup_name_with_maybe_tool_name not in toolgroups_for_turn_set:
|
||||
continue
|
||||
|
||||
toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
||||
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
|
||||
if not tools.data:
|
||||
|
|
@ -1029,7 +983,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
|
|||
path = urlparse(uri).path
|
||||
basename = os.path.basename(path)
|
||||
filepath = f"{tempdir}/{make_random_string() + basename}"
|
||||
log.info(f"Downloading {url} -> {filepath}")
|
||||
logcat.info("agents", f"Downloading {url} -> {filepath}")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(uri)
|
||||
|
|
@ -1069,6 +1023,7 @@ async def execute_tool_call_maybe(
|
|||
else:
|
||||
name = name.value
|
||||
|
||||
logcat.info("agents", f"executing tool call: {name} with args: {tool_call.arguments}")
|
||||
result = await tool_runtime_api.invoke_tool(
|
||||
tool_name=name,
|
||||
kwargs={
|
||||
|
|
@ -1078,6 +1033,7 @@ async def execute_tool_call_maybe(
|
|||
**toolgroup_args.get(group_name, {}),
|
||||
},
|
||||
)
|
||||
logcat.debug("agents", f"tool call {name} completed with result: {result}")
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from llama_stack.apis.agents import (
|
|||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
ToolConfig,
|
||||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
|
@ -140,7 +141,6 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
documents: Optional[List[Document]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
allow_turn_resume: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnCreateRequest(
|
||||
agent_id=agent_id,
|
||||
|
|
@ -150,7 +150,6 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
toolgroups=toolgroups,
|
||||
documents=documents,
|
||||
tool_config=tool_config,
|
||||
allow_turn_resume=allow_turn_resume,
|
||||
)
|
||||
if stream:
|
||||
return self._create_agent_turn_streaming(request)
|
||||
|
|
@ -170,7 +169,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: List[ToolResponseMessage],
|
||||
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]],
|
||||
stream: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnResumeRequest(
|
||||
|
|
|
|||
|
|
@ -105,3 +105,15 @@ class AgentPersistence:
|
|||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
return ToolExecutionStep(**json.loads(value)) if value else None
|
||||
|
||||
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
|
||||
await self.kvstore.set(
|
||||
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||
value=str(num_infer_iters),
|
||||
)
|
||||
|
||||
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
|
||||
value = await self.kvstore.get(
|
||||
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
return int(value) if value else None
|
||||
|
|
|
|||
|
|
@ -1,400 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import tempfile
|
||||
from typing import AsyncIterator, List, Optional, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
AgentToolGroupWithArgs,
|
||||
AgentTurnCreateRequest,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
StepType,
|
||||
)
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import RunShieldResponse
|
||||
from llama_stack.apis.tools import (
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolHost,
|
||||
ToolInvocationResult,
|
||||
)
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool
|
||||
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
|
||||
MEMORY_QUERY_TOOL,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.agents import (
|
||||
MetaReferenceAgentsImpl,
|
||||
MetaReferenceAgentsImplConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
class MockInferenceAPI:
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = None,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
async def stream_response():
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type="start",
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type="progress",
|
||||
delta="AI is a fascinating field...",
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type="complete",
|
||||
delta="",
|
||||
stop_reason="end_of_turn",
|
||||
)
|
||||
)
|
||||
|
||||
if stream:
|
||||
return stream_response()
|
||||
else:
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
role="assistant",
|
||||
content="Mock response",
|
||||
stop_reason="end_of_turn",
|
||||
),
|
||||
logprobs={"token_logprobs": [0.1, 0.2, 0.3]} if logprobs else None,
|
||||
)
|
||||
|
||||
|
||||
class MockSafetyAPI:
|
||||
async def run_shield(self, shield_id: str, messages: List[Message]) -> RunShieldResponse:
|
||||
return RunShieldResponse(violation=None)
|
||||
|
||||
|
||||
class MockVectorIOAPI:
|
||||
def __init__(self):
|
||||
self.chunks = {}
|
||||
|
||||
async def insert_chunks(self, vector_db_id, chunks, ttl_seconds=None):
|
||||
for chunk in chunks:
|
||||
metadata = chunk.metadata
|
||||
self.chunks[vector_db_id][metadata["document_id"]] = chunk
|
||||
|
||||
async def query_chunks(self, vector_db_id, query, params=None):
|
||||
if vector_db_id not in self.chunks:
|
||||
raise ValueError(f"Bank {vector_db_id} not found")
|
||||
|
||||
chunks = list(self.chunks[vector_db_id].values())
|
||||
scores = [1.0] * len(chunks)
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class MockToolGroupsAPI:
|
||||
async def register_tool_group(self, toolgroup_id: str, provider_id: str, mcp_endpoint=None, args=None) -> None:
|
||||
pass
|
||||
|
||||
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
|
||||
return ToolGroup(
|
||||
identifier=toolgroup_id,
|
||||
provider_resource_id=toolgroup_id,
|
||||
)
|
||||
|
||||
async def list_tool_groups(self) -> List[ToolGroup]:
|
||||
return []
|
||||
|
||||
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
|
||||
if tool_group_id == MEMORY_TOOLGROUP:
|
||||
return [
|
||||
Tool(
|
||||
identifier=MEMORY_QUERY_TOOL,
|
||||
provider_resource_id=MEMORY_QUERY_TOOL,
|
||||
toolgroup_id=MEMORY_TOOLGROUP,
|
||||
tool_host=ToolHost.client,
|
||||
description="Mock tool",
|
||||
provider_id="builtin::rag",
|
||||
parameters=[],
|
||||
)
|
||||
]
|
||||
if tool_group_id == CODE_INTERPRETER_TOOLGROUP:
|
||||
return [
|
||||
Tool(
|
||||
identifier="code_interpreter",
|
||||
provider_resource_id="code_interpreter",
|
||||
toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
|
||||
tool_host=ToolHost.client,
|
||||
description="Mock tool",
|
||||
provider_id="builtin::code_interpreter",
|
||||
parameters=[],
|
||||
)
|
||||
]
|
||||
return []
|
||||
|
||||
async def get_tool(self, tool_name: str) -> Tool:
|
||||
return Tool(
|
||||
identifier=tool_name,
|
||||
provider_resource_id=tool_name,
|
||||
toolgroup_id="mock_group",
|
||||
tool_host=ToolHost.client,
|
||||
description="Mock tool",
|
||||
provider_id="mock_provider",
|
||||
parameters=[],
|
||||
)
|
||||
|
||||
async def unregister_tool_group(self, tool_group_id: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class MockToolRuntimeAPI:
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
return []
|
||||
|
||||
async def invoke_tool(self, tool_name: str, args: dict) -> ToolInvocationResult:
|
||||
return ToolInvocationResult(content={"result": "Mock tool result"})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_inference_api():
|
||||
return MockInferenceAPI()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_safety_api():
|
||||
return MockSafetyAPI()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_io_api():
|
||||
return MockVectorIOAPI()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_groups_api():
|
||||
return MockToolGroupsAPI()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_runtime_api():
|
||||
return MockToolRuntimeAPI()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def get_agents_impl(
|
||||
mock_inference_api,
|
||||
mock_safety_api,
|
||||
mock_vector_io_api,
|
||||
mock_tool_runtime_api,
|
||||
mock_tool_groups_api,
|
||||
):
|
||||
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||
impl = MetaReferenceAgentsImpl(
|
||||
config=MetaReferenceAgentsImplConfig(
|
||||
persistence_store=SqliteKVStoreConfig(
|
||||
db_name=sqlite_file.name,
|
||||
),
|
||||
),
|
||||
inference_api=mock_inference_api,
|
||||
safety_api=mock_safety_api,
|
||||
vector_io_api=mock_vector_io_api,
|
||||
tool_runtime_api=mock_tool_runtime_api,
|
||||
tool_groups_api=mock_tool_groups_api,
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def get_chat_agent(get_agents_impl):
|
||||
impl = await get_agents_impl
|
||||
agent_config = AgentConfig(
|
||||
model="test_model",
|
||||
instructions="You are a helpful assistant.",
|
||||
toolgroups=[],
|
||||
tool_choice=ToolChoice.auto,
|
||||
enable_session_persistence=False,
|
||||
input_shields=["test_shield"],
|
||||
)
|
||||
response = await impl.create_agent(agent_config)
|
||||
return await impl.get_agent(response.agent_id)
|
||||
|
||||
|
||||
MEMORY_TOOLGROUP = "builtin::rag"
|
||||
CODE_INTERPRETER_TOOLGROUP = "builtin::code_interpreter"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def get_chat_agent_with_tools(get_agents_impl, request):
|
||||
impl = await get_agents_impl
|
||||
toolgroups = request.param
|
||||
agent_config = AgentConfig(
|
||||
model="test_model",
|
||||
instructions="You are a helpful assistant.",
|
||||
toolgroups=toolgroups,
|
||||
tool_choice=ToolChoice.auto,
|
||||
enable_session_persistence=False,
|
||||
input_shields=["test_shield"],
|
||||
)
|
||||
response = await impl.create_agent(agent_config)
|
||||
return await impl.get_agent(response.agent_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_agent_create_and_execute_turn(get_chat_agent):
|
||||
chat_agent = await get_chat_agent
|
||||
session_id = await chat_agent.create_session("Test Session")
|
||||
request = AgentTurnCreateRequest(
|
||||
agent_id=chat_agent.agent_id,
|
||||
session_id=session_id,
|
||||
messages=[UserMessage(content="Hello")],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
responses = []
|
||||
async for response in chat_agent.create_and_execute_turn(request):
|
||||
responses.append(response)
|
||||
|
||||
assert len(responses) > 0
|
||||
assert (
|
||||
len(responses) == 7
|
||||
) # TurnStart, ShieldCallStart, ShieldCallComplete, StepStart, StepProgress, StepComplete, TurnComplete
|
||||
assert responses[0].event.payload.turn_id is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_multiple_shields_wrapper(get_chat_agent):
|
||||
chat_agent = await get_chat_agent
|
||||
messages = [UserMessage(content="Test message")]
|
||||
shields = ["test_shield"]
|
||||
|
||||
responses = [
|
||||
chunk
|
||||
async for chunk in chat_agent.run_multiple_shields_wrapper(
|
||||
turn_id="test_turn_id",
|
||||
messages=messages,
|
||||
shields=shields,
|
||||
touchpoint="user-input",
|
||||
)
|
||||
]
|
||||
|
||||
assert len(responses) == 2 # StepStart, StepComplete
|
||||
assert responses[0].event.payload.step_type.value == "shield_call"
|
||||
assert not responses[1].event.payload.step_details.violation
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_agent_complex_turn(get_chat_agent):
|
||||
chat_agent = await get_chat_agent
|
||||
session_id = await chat_agent.create_session("Test Session")
|
||||
request = AgentTurnCreateRequest(
|
||||
agent_id=chat_agent.agent_id,
|
||||
session_id=session_id,
|
||||
messages=[UserMessage(content="Tell me about AI and then use a tool.")],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
responses = []
|
||||
async for response in chat_agent.create_and_execute_turn(request):
|
||||
responses.append(response)
|
||||
|
||||
assert len(responses) > 0
|
||||
|
||||
step_types = [
|
||||
response.event.payload.step_type for response in responses if hasattr(response.event.payload, "step_type")
|
||||
]
|
||||
|
||||
assert StepType.shield_call in step_types, "Shield call step is missing"
|
||||
assert StepType.inference in step_types, "Inference step is missing"
|
||||
|
||||
event_types = [
|
||||
response.event.payload.event_type for response in responses if hasattr(response.event.payload, "event_type")
|
||||
]
|
||||
assert "turn_start" in event_types, "Start event is missing"
|
||||
assert "turn_complete" in event_types, "Complete event is missing"
|
||||
|
||||
assert any(isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload) for response in responses), (
|
||||
"Turn complete event is missing"
|
||||
)
|
||||
turn_complete_payload = next(
|
||||
response.event.payload
|
||||
for response in responses
|
||||
if isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
|
||||
)
|
||||
turn = turn_complete_payload.turn
|
||||
assert turn.input_messages == request.messages, "Input messages do not match"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"toolgroups, expected_memory, expected_code_interpreter",
|
||||
[
|
||||
([], False, False), # no tools
|
||||
([MEMORY_TOOLGROUP], True, False), # memory only
|
||||
([CODE_INTERPRETER_TOOLGROUP], False, True), # code interpreter only
|
||||
([MEMORY_TOOLGROUP, CODE_INTERPRETER_TOOLGROUP], True, True), # all tools
|
||||
],
|
||||
)
|
||||
async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, expected_code_interpreter):
|
||||
impl = await get_agents_impl
|
||||
agent_config = AgentConfig(
|
||||
model="test_model",
|
||||
instructions="You are a helpful assistant.",
|
||||
toolgroups=toolgroups,
|
||||
tool_choice=ToolChoice.auto,
|
||||
enable_session_persistence=False,
|
||||
input_shields=["test_shield"],
|
||||
)
|
||||
response = await impl.create_agent(agent_config)
|
||||
chat_agent = await impl.get_agent(response.agent_id)
|
||||
|
||||
tool_defs, _ = await chat_agent._get_tool_defs()
|
||||
if expected_memory:
|
||||
assert MEMORY_QUERY_TOOL in tool_defs
|
||||
if expected_code_interpreter:
|
||||
assert BuiltinTool.code_interpreter in tool_defs
|
||||
if expected_memory and expected_code_interpreter:
|
||||
# override the tools for turn
|
||||
new_tool_defs, _ = await chat_agent._get_tool_defs(
|
||||
toolgroups_for_turn=[
|
||||
AgentToolGroupWithArgs(
|
||||
name=MEMORY_TOOLGROUP,
|
||||
args={"vector_dbs": ["test_vector_db"]},
|
||||
)
|
||||
]
|
||||
)
|
||||
assert MEMORY_QUERY_TOOL in new_tool_defs
|
||||
assert BuiltinTool.code_interpreter not in new_tool_defs
|
||||
|
|
@ -3,6 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from tqdm import tqdm
|
||||
|
|
@ -82,23 +83,22 @@ class MetaReferenceEvalImpl(
|
|||
async def run_eval(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
task_config: BenchmarkConfig,
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
task_def = self.benchmarks[benchmark_id]
|
||||
dataset_id = task_def.dataset_id
|
||||
candidate = task_config.eval_candidate
|
||||
scoring_functions = task_def.scoring_functions
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=(-1 if task_config.num_examples is None else task_config.num_examples),
|
||||
rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
|
||||
)
|
||||
res = await self.evaluate_rows(
|
||||
benchmark_id=benchmark_id,
|
||||
input_rows=all_rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
task_config=task_config,
|
||||
benchmark_config=benchmark_config,
|
||||
)
|
||||
|
||||
# TODO: currently needs to wait for generation before returning
|
||||
|
|
@ -108,16 +108,16 @@ class MetaReferenceEvalImpl(
|
|||
return Job(job_id=job_id)
|
||||
|
||||
async def _run_agent_generation(
|
||||
self, input_rows: List[Dict[str, Any]], task_config: BenchmarkConfig
|
||||
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||
) -> List[Dict[str, Any]]:
|
||||
candidate = task_config.eval_candidate
|
||||
candidate = benchmark_config.eval_candidate
|
||||
create_response = await self.agents_api.create_agent(candidate.config)
|
||||
agent_id = create_response.agent_id
|
||||
|
||||
generations = []
|
||||
for i, x in tqdm(enumerate(input_rows)):
|
||||
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
|
||||
input_messages = eval(str(x[ColumnName.chat_completion_input.value]))
|
||||
input_messages = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [UserMessage(**x) for x in input_messages]
|
||||
|
||||
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
|
||||
|
|
@ -151,15 +151,15 @@ class MetaReferenceEvalImpl(
|
|||
return generations
|
||||
|
||||
async def _run_model_generation(
|
||||
self, input_rows: List[Dict[str, Any]], task_config: BenchmarkConfig
|
||||
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||
) -> List[Dict[str, Any]]:
|
||||
candidate = task_config.eval_candidate
|
||||
candidate = benchmark_config.eval_candidate
|
||||
assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
|
||||
|
||||
generations = []
|
||||
for x in tqdm(input_rows):
|
||||
if ColumnName.completion_input.value in x:
|
||||
input_content = eval(str(x[ColumnName.completion_input.value]))
|
||||
input_content = json.loads(x[ColumnName.completion_input.value])
|
||||
response = await self.inference_api.completion(
|
||||
model=candidate.model,
|
||||
content=input_content,
|
||||
|
|
@ -167,9 +167,8 @@ class MetaReferenceEvalImpl(
|
|||
)
|
||||
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
|
||||
elif ColumnName.chat_completion_input.value in x:
|
||||
chat_completion_input_str = str(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = eval(chat_completion_input_str)
|
||||
input_messages = [UserMessage(**x) for x in input_messages]
|
||||
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [UserMessage(**x) for x in chat_completion_input_json]
|
||||
messages = []
|
||||
if candidate.system_message:
|
||||
messages.append(candidate.system_message)
|
||||
|
|
@ -190,13 +189,13 @@ class MetaReferenceEvalImpl(
|
|||
benchmark_id: str,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
task_config: BenchmarkConfig,
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
candidate = task_config.eval_candidate
|
||||
candidate = benchmark_config.eval_candidate
|
||||
if candidate.type == "agent":
|
||||
generations = await self._run_agent_generation(input_rows, task_config)
|
||||
generations = await self._run_agent_generation(input_rows, benchmark_config)
|
||||
elif candidate.type == "model":
|
||||
generations = await self._run_model_generation(input_rows, task_config)
|
||||
generations = await self._run_model_generation(input_rows, benchmark_config)
|
||||
else:
|
||||
raise ValueError(f"Invalid candidate type: {candidate.type}")
|
||||
|
||||
|
|
@ -205,9 +204,9 @@ class MetaReferenceEvalImpl(
|
|||
input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
|
||||
]
|
||||
|
||||
if task_config.scoring_params is not None:
|
||||
if benchmark_config.scoring_params is not None:
|
||||
scoring_functions_dict = {
|
||||
scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
|
||||
scoring_fn_id: benchmark_config.scoring_params.get(scoring_fn_id, None)
|
||||
for scoring_fn_id in scoring_functions
|
||||
}
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
|
||||
|
||||
class TokenResult(BaseModel):
|
||||
token: int
|
||||
text: str
|
||||
logprobs: Optional[List[float]] = None
|
||||
|
||||
|
||||
def model_checkpoint_dir(model_id) -> str:
|
||||
checkpoint_dir = Path(model_local_dir(model_id))
|
||||
|
||||
paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
|
||||
if not any(p.exists() for p in paths):
|
||||
checkpoint_dir = checkpoint_dir / "original"
|
||||
|
||||
assert checkpoint_dir.exists(), (
|
||||
f"Could not find checkpoints in: {model_local_dir(model_id)}. "
|
||||
f"If you try to use the native llama model, Please download model using `llama download --model-id {model_id}`"
|
||||
f"Otherwise, please save you model checkpoint under {model_local_dir(model_id)}"
|
||||
)
|
||||
return str(checkpoint_dir)
|
||||
|
|
@ -55,7 +55,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .generation import Llama
|
||||
from .llama3.generation import Llama3
|
||||
from .model_parallel import LlamaModelParallelGenerator
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
@ -83,7 +83,7 @@ class MetaReferenceInferenceImpl(
|
|||
self.generator = LlamaModelParallelGenerator(self.config, model_id, llama_model)
|
||||
self.generator.start()
|
||||
else:
|
||||
self.generator = Llama.build(self.config, model_id, llama_model)
|
||||
self.generator = Llama3.build(self.config, model_id, llama_model)
|
||||
|
||||
self.model_id = model_id
|
||||
self.llama_model = llama_model
|
||||
|
|
@ -111,7 +111,7 @@ class MetaReferenceInferenceImpl(
|
|||
)
|
||||
if llama_model is None:
|
||||
raise ValueError(
|
||||
"Please make sure your llama_model in model metadata or model identifier is in llama-models SKU list"
|
||||
"Please make sure your llama_model in model metadata or model identifier is in Llama SKU list"
|
||||
)
|
||||
|
||||
self.model_registry_helper = ModelRegistryHelper(
|
||||
|
|
@ -136,11 +136,13 @@ class MetaReferenceInferenceImpl(
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
|
|
@ -208,7 +210,6 @@ class MetaReferenceInferenceImpl(
|
|||
logprobs = []
|
||||
stop_reason = None
|
||||
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
for token_result in self.generator.completion(request):
|
||||
tokens.append(token_result.token)
|
||||
if token_result.text == "<|eot_id|>":
|
||||
|
|
@ -245,7 +246,7 @@ class MetaReferenceInferenceImpl(
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
|
@ -254,6 +255,8 @@ class MetaReferenceInferenceImpl(
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,82 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class QuantizationScheme(Enum):
|
||||
int4_weight_int8_dynamic_activation = "int4_weight_int8_dynamic_activation"
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuantizationArgs:
|
||||
scheme: Optional[QuantizationScheme] = None
|
||||
group_size: Optional[int] = None
|
||||
spinquant: bool = False
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if k == "scheme":
|
||||
setattr(self, k, QuantizationScheme(v))
|
||||
else:
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAArgs:
|
||||
rank: int
|
||||
scale: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
dim: int = 4096
|
||||
n_layers: int = 32
|
||||
n_heads: int = 32
|
||||
n_kv_heads: Optional[int] = None
|
||||
vocab_size: int = -1
|
||||
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
|
||||
ffn_dim_multiplier: Optional[float] = None
|
||||
norm_eps: float = 1e-5
|
||||
rope_theta: float = 500000
|
||||
use_scaled_rope: bool = False
|
||||
|
||||
max_batch_size: int = 32
|
||||
max_seq_len: int = 2048
|
||||
|
||||
# vision model params
|
||||
vision_chunk_size: int = -1 # image resolution for image models
|
||||
vision_max_num_chunks: int = 4
|
||||
vision_num_cross_attention_layers: int = -1
|
||||
|
||||
quantization_args: Optional[QuantizationArgs] = None
|
||||
lora_args: Optional[LoRAArgs] = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if k == "lora_args":
|
||||
setattr(self, k, LoRAArgs(**v))
|
||||
elif k == "quantization_args":
|
||||
setattr(self, k, QuantizationArgs(**v))
|
||||
else:
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
if self.n_kv_heads is None:
|
||||
self.n_kv_heads = self.n_heads
|
||||
assert self.n_kv_heads <= self.n_heads
|
||||
assert self.n_heads % self.n_kv_heads == 0
|
||||
assert self.dim % self.n_heads == 0
|
||||
|
|
@ -23,15 +23,7 @@ from fairscale.nn.model_parallel.initialize import (
|
|||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from llama_models.llama3.api.args import ModelArgs
|
||||
from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.llama3.reference_impl.model import Transformer
|
||||
from llama_models.llama3.reference_impl.multimodal.model import (
|
||||
CrossAttentionTransformer,
|
||||
)
|
||||
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
Fp8QuantizationConfig,
|
||||
|
|
@ -39,46 +31,30 @@ from llama_stack.apis.inference import (
|
|||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
)
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
GreedySamplingStrategy,
|
||||
Model,
|
||||
SamplingParams,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat, LLMInput
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
from ..common import TokenResult, model_checkpoint_dir
|
||||
from ..config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
from .args import ModelArgs
|
||||
from .model import Transformer
|
||||
from .multimodal.model import CrossAttentionTransformer
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def model_checkpoint_dir(model_id) -> str:
|
||||
checkpoint_dir = Path(model_local_dir(model_id))
|
||||
|
||||
paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
|
||||
if not any(p.exists() for p in paths):
|
||||
checkpoint_dir = checkpoint_dir / "original"
|
||||
|
||||
assert checkpoint_dir.exists(), (
|
||||
f"Could not find checkpoints in: {model_local_dir(model_id)}. "
|
||||
f"If you try to use the native llama model, Please download model using `llama download --model-id {model_id}`"
|
||||
f"Otherwise, please save you model checkpoint under {model_local_dir(model_id)}"
|
||||
)
|
||||
return str(checkpoint_dir)
|
||||
|
||||
|
||||
class TokenResult(BaseModel):
|
||||
token: int
|
||||
text: str
|
||||
logprobs: Optional[List[float]] = None
|
||||
|
||||
|
||||
class Llama:
|
||||
class Llama3:
|
||||
@staticmethod
|
||||
def build(
|
||||
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
|
||||
|
|
@ -170,7 +146,7 @@ class Llama:
|
|||
|
||||
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
|
||||
if isinstance(config.quantization, Fp8QuantizationConfig):
|
||||
from .quantization.loader import convert_to_fp8_quantized_model
|
||||
from ..quantization.loader import convert_to_fp8_quantized_model
|
||||
|
||||
# load on CPU in bf16 so that fp8 conversion does not find an
|
||||
# unexpected (fp32, e.g.) datatype
|
||||
|
|
@ -183,7 +159,7 @@ class Llama:
|
|||
model.load_state_dict(state_dict, strict=False)
|
||||
model = convert_to_fp8_quantized_model(model, config, ckpt_dir)
|
||||
elif isinstance(config.quantization, Int4QuantizationConfig):
|
||||
from .quantization.loader import convert_to_int4_quantized_model
|
||||
from ..quantization.loader import convert_to_int4_quantized_model
|
||||
|
||||
model = Transformer(model_args)
|
||||
model = convert_to_int4_quantized_model(model, model_args, config)
|
||||
|
|
@ -193,7 +169,7 @@ class Llama:
|
|||
# Add a wrapper for adding hadamard transform for spinquant.
|
||||
# This needs to be done after loading the state dict otherwise an error will be raised while
|
||||
# loading the state dict.
|
||||
from .quantization.hadamard_utils import (
|
||||
from ..quantization.hadamard_utils import (
|
||||
add_hadamard_transform_for_spinquant,
|
||||
)
|
||||
|
||||
|
|
@ -222,7 +198,7 @@ class Llama:
|
|||
model.to(device)
|
||||
|
||||
log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
|
||||
return Llama(model, tokenizer, model_args, llama_model_id)
|
||||
return Llama3(model, tokenizer, model_args, llama_model_id)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -0,0 +1,311 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import fairscale.nn.model_parallel.initialize as fs_init
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.layers import (
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
from .args import ModelArgs
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
return output * self.weight
|
||||
|
||||
|
||||
def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
|
||||
# Values obtained from grid search
|
||||
scale_factor = 8
|
||||
low_freq_factor = 1
|
||||
high_freq_factor = 4
|
||||
old_context_len = 8192 # original llama3 length
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
|
||||
wavelen = 2 * torch.pi / freqs
|
||||
new_freqs = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, freqs)
|
||||
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||
return torch.where(
|
||||
(wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
|
||||
(1 - smooth) * new_freqs / scale_factor + smooth * new_freqs,
|
||||
new_freqs,
|
||||
)
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
|
||||
if use_scaled:
|
||||
freqs = apply_scaling(freqs)
|
||||
freqs = torch.outer(t, freqs)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
||||
ndim = x.ndim
|
||||
assert 0 <= 1 < ndim
|
||||
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis.view(*shape)
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
|
||||
bs, slen, n_kv_heads, head_dim = x.shape
|
||||
if n_rep == 1:
|
||||
return x
|
||||
return (
|
||||
x[:, :, :, None, :]
|
||||
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
|
||||
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
|
||||
)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||
model_parallel_size = fs_init.get_model_parallel_world_size()
|
||||
self.n_local_heads = args.n_heads // model_parallel_size
|
||||
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
|
||||
self.wq = ColumnParallelLinear(
|
||||
args.dim,
|
||||
args.n_heads * self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wk = ColumnParallelLinear(
|
||||
args.dim,
|
||||
self.n_kv_heads * self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wv = ColumnParallelLinear(
|
||||
args.dim,
|
||||
self.n_kv_heads * self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wo = RowParallelLinear(
|
||||
args.n_heads * self.head_dim,
|
||||
args.dim,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
|
||||
self.cache_k = torch.zeros(
|
||||
(
|
||||
args.max_batch_size,
|
||||
args.max_seq_len,
|
||||
self.n_local_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
)
|
||||
self.cache_v = torch.zeros(
|
||||
(
|
||||
args.max_batch_size,
|
||||
args.max_seq_len,
|
||||
self.n_local_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
start_pos: int,
|
||||
freqs_cis: torch.Tensor,
|
||||
mask: Optional[torch.Tensor],
|
||||
):
|
||||
bsz, seqlen, _ = x.shape
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
||||
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
|
||||
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
||||
|
||||
self.cache_k = self.cache_k.to(xq)
|
||||
self.cache_v = self.cache_v.to(xq)
|
||||
|
||||
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
|
||||
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
|
||||
|
||||
keys = self.cache_k[:bsz, : start_pos + seqlen]
|
||||
values = self.cache_v[:bsz, : start_pos + seqlen]
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
keys = repeat_kv(keys, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
|
||||
values = repeat_kv(values, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
|
||||
|
||||
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
|
||||
keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
|
||||
values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
|
||||
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
if mask is not None:
|
||||
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
|
||||
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
|
||||
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
|
||||
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
||||
return self.wo(output)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: Optional[float],
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
# custom dim factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
|
||||
self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
|
||||
self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, layer_id: int, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_heads = args.n_heads
|
||||
self.dim = args.dim
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
self.attention = Attention(args)
|
||||
self.feed_forward = FeedForward(
|
||||
dim=args.dim,
|
||||
hidden_dim=4 * args.dim,
|
||||
multiple_of=args.multiple_of,
|
||||
ffn_dim_multiplier=args.ffn_dim_multiplier,
|
||||
)
|
||||
self.layer_id = layer_id
|
||||
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
start_pos: int,
|
||||
freqs_cis: torch.Tensor,
|
||||
mask: Optional[torch.Tensor],
|
||||
):
|
||||
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
|
||||
out = h + self.feed_forward(self.ffn_norm(h))
|
||||
return out
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, params: ModelArgs):
|
||||
super().__init__()
|
||||
self.params = params
|
||||
self.vocab_size = params.vocab_size
|
||||
self.n_layers = params.n_layers
|
||||
|
||||
self.tok_embeddings = VocabParallelEmbedding(params.vocab_size, params.dim, init_method=lambda x: x)
|
||||
|
||||
self.layers = torch.nn.ModuleList()
|
||||
for layer_id in range(params.n_layers):
|
||||
self.layers.append(TransformerBlock(layer_id, params))
|
||||
|
||||
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
|
||||
self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x)
|
||||
|
||||
self.freqs_cis = precompute_freqs_cis(
|
||||
params.dim // params.n_heads,
|
||||
params.max_seq_len * 2,
|
||||
params.rope_theta,
|
||||
params.use_scaled_rope,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, tokens: torch.Tensor, start_pos: int):
|
||||
_bsz, seqlen = tokens.shape
|
||||
h = self.tok_embeddings(tokens)
|
||||
self.freqs_cis = self.freqs_cis.to(h.device)
|
||||
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
|
||||
|
||||
mask = None
|
||||
if seqlen > 1:
|
||||
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
|
||||
|
||||
mask = torch.triu(mask, diagonal=1)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/100005
|
||||
# torch.triu is buggy when the device is mps: filled values are
|
||||
# nan instead of 0.
|
||||
if mask.device.type == torch.device("mps").type:
|
||||
mask = torch.nan_to_num(mask, nan=0.0)
|
||||
|
||||
# When performing key-value caching, we compute the attention scores
|
||||
# only for the new sequence. Thus, the matrix of scores is of size
|
||||
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
|
||||
# j > cache_len + i, since row i corresponds to token cache_len + i.
|
||||
mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
|
||||
|
||||
for layer in self.layers:
|
||||
h = layer(h, start_pos, freqs_cis, mask)
|
||||
h = self.norm(h)
|
||||
output = self.output(h).float()
|
||||
return output
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
|
@ -0,0 +1,179 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and its affiliates.
|
||||
import math
|
||||
from logging import getLogger
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .utils import get_negative_inf_value, to_2tuple
|
||||
|
||||
logger = getLogger()
|
||||
|
||||
|
||||
def resize_local_position_embedding(orig_pos_embed, grid_size):
|
||||
"""
|
||||
Resize position embedding for vision encoder.
|
||||
Original position embedding is [n_tiles * n_tiles + 1, dim]
|
||||
New position embedding will be [grid_size[0] * grid_size[1] + 1, dim]
|
||||
"""
|
||||
new_grid_size = to_2tuple(grid_size)
|
||||
orig_grid_size = to_2tuple(int(math.sqrt(len(orig_pos_embed) - 1)))
|
||||
|
||||
new_pos_emb_tok, new_pos_emb_img = (
|
||||
orig_pos_embed[:1],
|
||||
orig_pos_embed[1:],
|
||||
)
|
||||
logger.info(f"resizing position embedding grid-size from {orig_grid_size} to {new_grid_size}")
|
||||
|
||||
new_pos_emb_img = new_pos_emb_img.reshape(1, orig_grid_size[0], orig_grid_size[1], -1).permute(0, 3, 1, 2)
|
||||
|
||||
new_pos_emb_img = F.interpolate(
|
||||
new_pos_emb_img,
|
||||
size=new_grid_size,
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1).reshape(1, new_grid_size[0] * new_grid_size[1], -1)[0]
|
||||
new_pos_embed = torch.cat([new_pos_emb_tok, new_pos_emb_img], dim=0)
|
||||
return new_pos_embed
|
||||
|
||||
|
||||
def initialize_global_position_embedding_from_local(pos_and_cls_embed, grid_size, x_scale, y_scale):
|
||||
"""
|
||||
Takes a local position embedding for vision encoder and uses it
|
||||
to initialize the global position embedding.
|
||||
Input: local position embedding of shape [grid_size[0] * grid_size[1] + 1, dim]
|
||||
Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
|
||||
Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.
|
||||
"""
|
||||
pos_embed = pos_and_cls_embed[1:]
|
||||
cls_embed = pos_and_cls_embed[0].view(1, 1, 1, -1)
|
||||
grid_size = to_2tuple(grid_size)
|
||||
new_pos_emb_img = pos_embed.reshape(1, grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2)
|
||||
new_grid_size = (x_scale * grid_size[0], y_scale * grid_size[1])
|
||||
new_pos_emb_img = F.interpolate(
|
||||
new_pos_emb_img,
|
||||
size=new_grid_size,
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1)
|
||||
new_pos_emb_img = new_pos_emb_img.view(x_scale, grid_size[0], y_scale, grid_size[1], -1)
|
||||
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 1, 3, 4).contiguous()
|
||||
new_pos_emb_img = new_pos_emb_img.reshape(x_scale, y_scale, grid_size[0] * grid_size[1], -1)
|
||||
cls_embed = cls_embed.expand(x_scale, y_scale, -1, -1)
|
||||
pos_and_cls_embed = torch.cat([cls_embed, new_pos_emb_img], dim=2)
|
||||
return pos_and_cls_embed
|
||||
|
||||
|
||||
def resize_global_position_embedding(pos_and_cls_embed, grid_size, x_scale, y_scale):
|
||||
"""
|
||||
Takes a global position embedding for vision encoder and resizes it to new size.
|
||||
Input: global position embedding of shape [x_old, y_old, old_grid_size[0] * old_grid_size[1] + 1, dim]
|
||||
Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
|
||||
Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.
|
||||
"""
|
||||
# first remove cls token
|
||||
pos_embed = pos_and_cls_embed[:, :, 1:]
|
||||
cls_embed = pos_and_cls_embed[:, :, 0].unsqueeze(2)
|
||||
|
||||
xs_old, ys_old, ntok, dim = pos_embed.shape
|
||||
old_grid_size = int(math.sqrt(ntok))
|
||||
|
||||
# move to correct form for interpolation
|
||||
pos_embed = pos_embed.view(xs_old, ys_old, old_grid_size, old_grid_size, dim)
|
||||
pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
|
||||
pos_embed = pos_embed.view(xs_old * old_grid_size, ys_old * old_grid_size, dim)
|
||||
pos_embed = pos_embed.unsqueeze(0)
|
||||
|
||||
# interpolate
|
||||
new_size = (grid_size[0] * x_scale, grid_size[1] * y_scale)
|
||||
pos_embed = pos_embed.permute(0, 3, 1, 2)
|
||||
pos_embed_resized = F.interpolate(
|
||||
pos_embed,
|
||||
size=new_size,
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
pos_embed = pos_embed_resized.permute(0, 2, 3, 1)[0]
|
||||
|
||||
# move it back in place
|
||||
pos_embed = pos_embed.view(x_scale, grid_size[0], y_scale, grid_size[1], dim)
|
||||
pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
|
||||
pos_embed = pos_embed.view(x_scale, y_scale, grid_size[0] * grid_size[1], dim)
|
||||
|
||||
# interpolate cls token
|
||||
cls_embed = cls_embed.permute(2, 3, 0, 1)
|
||||
cls_embed_resized = F.interpolate(
|
||||
cls_embed,
|
||||
size=(x_scale, y_scale),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
cls_embed = cls_embed_resized.permute(2, 3, 0, 1)
|
||||
# add cls token back in
|
||||
pos_and_cls_embed = torch.cat([cls_embed, pos_embed], dim=2)
|
||||
|
||||
return pos_and_cls_embed
|
||||
|
||||
|
||||
def build_encoder_attention_mask(
|
||||
x: torch.Tensor,
|
||||
ar: torch.Tensor,
|
||||
ntok: int,
|
||||
num_chunks: int,
|
||||
n_heads: int,
|
||||
):
|
||||
"""
|
||||
Build vision encoder attention mask that omits padding tokens.
|
||||
"""
|
||||
masks = []
|
||||
for arx in ar:
|
||||
mask_i = torch.ones((num_chunks, x.shape[2], 1), dtype=x.dtype)
|
||||
mask_i[: arx[0] * arx[1], :ntok] = 0
|
||||
mask_i = mask_i.view(num_chunks * x.shape[2], -1)
|
||||
mask_i = mask_i @ mask_i.T * get_negative_inf_value(x.dtype)
|
||||
mask_i = mask_i.unsqueeze(0)
|
||||
masks.append(mask_i)
|
||||
masks = torch.stack(masks).to(x.device).expand(-1, n_heads, -1, -1)
|
||||
return masks
|
||||
|
||||
|
||||
def expand_num_tokens_to_mult8(x):
|
||||
num_pad_tokens = 8 - (x.shape[-2] % 8)
|
||||
if num_pad_tokens == 0:
|
||||
return x, 0
|
||||
else:
|
||||
return (
|
||||
torch.cat(
|
||||
[
|
||||
x,
|
||||
torch.zeros(
|
||||
(x.shape[0], x.shape[1], num_pad_tokens, x.shape[-1]),
|
||||
dtype=x.dtype,
|
||||
device=x.device,
|
||||
),
|
||||
],
|
||||
dim=-2,
|
||||
),
|
||||
num_pad_tokens,
|
||||
)
|
||||
|
||||
|
||||
def contract_num_tokens_from_mult8(x, num_pad_tokens):
|
||||
if num_pad_tokens == 0:
|
||||
return x
|
||||
return x[:, :, :-num_pad_tokens]
|
||||
|
|
@ -0,0 +1,408 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from logging import getLogger
|
||||
from typing import Any, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv
|
||||
from PIL import Image
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
IMAGE_RES = 224
|
||||
|
||||
logger = getLogger()
|
||||
|
||||
|
||||
class VariableSizeImageTransform(object):
|
||||
"""
|
||||
This class accepts images of any size and dynamically resize, pads and chunks it
|
||||
based on the image aspect ratio and the number of image chunks we allow.
|
||||
|
||||
The algorithm will NOT distort the image fit a certain aspect ratio, because
|
||||
that leads to a significant degradation in image quality.
|
||||
|
||||
It can be summarized in 6 steps:
|
||||
1. Find all possible canvas combinations of max_num_chunks;
|
||||
2. Find the best canvas to fit the image;
|
||||
3. Resize without distortion
|
||||
4. Pad
|
||||
5. Normalize
|
||||
6. Chunk
|
||||
|
||||
For example, if an input image is of size 300x800, patch_size of 224,
|
||||
and max_num_chunks = 8, it will find the closest aspect ratio that
|
||||
is allowed within 8 image chunks, with some restrictions.
|
||||
In this case, 2:4 = 2 horizontal patches and 4 vertical patches,
|
||||
giving a total of 8 chunks.
|
||||
|
||||
If resize_to_max_canvas, the image will be resized (without distortion),
|
||||
to the largest possible resolution. In this case, 388:896, and padded to 448:896,
|
||||
where we maintain the original aspect ratio and pad with zeros value for the rest.
|
||||
This approach minimizes the amount of padding required for any arbitrary resolution.
|
||||
|
||||
However, if limit_upscaling_to_patch_size is set to True,
|
||||
the upscaling will be limited to the patch size. In the example above,
|
||||
the image would remain 300x800 (no upscaling), and then padded to 448:896.
|
||||
|
||||
The final output will therefore be of shape (8, 3, 224, 224), where 2x4
|
||||
patches are coming from the resizing and chunking.
|
||||
"""
|
||||
|
||||
def __init__(self, size: int = IMAGE_RES) -> None:
|
||||
self.size = size
|
||||
logger.info(f"VariableSizeImageTransform size: {self.size}")
|
||||
self.to_tensor = tv.ToTensor()
|
||||
self._mean = (0.48145466, 0.4578275, 0.40821073)
|
||||
self._std = (0.26862954, 0.26130258, 0.27577711)
|
||||
self.normalize = tv.Normalize(
|
||||
mean=self._mean,
|
||||
std=self._std,
|
||||
inplace=True,
|
||||
)
|
||||
self.resample = tv.InterpolationMode.BILINEAR
|
||||
|
||||
@staticmethod
|
||||
def get_factors(n: int) -> Set[int]:
|
||||
"""
|
||||
Calculate all factors of a given number, i.e. a dividor that leaves
|
||||
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
|
||||
|
||||
Args:
|
||||
n (int): The number to find factors for.
|
||||
|
||||
Returns:
|
||||
set: A set containing all factors of the number.
|
||||
"""
|
||||
factors_set = set()
|
||||
|
||||
for i in range(1, int(n**0.5) + 1):
|
||||
if n % i == 0:
|
||||
factors_set.add(i)
|
||||
factors_set.add(n // i)
|
||||
return factors_set
|
||||
|
||||
def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> torch.Tensor:
|
||||
"""
|
||||
Computes all of the allowed resoltuions for a fixed number of chunks
|
||||
and patch_size. Useful for when dividing an image into chunks.
|
||||
|
||||
Args:
|
||||
max_num_chunks (int): Maximum number of chunks for processing.
|
||||
patch_size (int): Size of the side of the patch.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: List of possible resolutions as tuples (height, width).
|
||||
|
||||
Example:
|
||||
>>> max_num_chunks = 5
|
||||
>>> patch_size = 224
|
||||
>>> find_supported_resolutions(max_num_chunks, patch_size)
|
||||
tensor([(224, 896), (448, 448), (224, 224), (896, 224), (224, 672),
|
||||
(672, 224), (224, 448), (448, 224)])
|
||||
|
||||
Given max_num_chunks=4, patch_size=224, it will create a dictionary:
|
||||
{
|
||||
0.25: [(1, 4)],
|
||||
1.0: [(2, 2), (1, 1)],
|
||||
4.0: [(4, 1)],
|
||||
0.33: [(1, 3)],
|
||||
3.0: [(3, 1)],
|
||||
0.5: [(1, 2)],
|
||||
2.0: [(2, 1)]
|
||||
}
|
||||
|
||||
and return the resolutions multiplied by the patch_size:
|
||||
[(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)]
|
||||
"""
|
||||
asp_dict = defaultdict(list)
|
||||
for chunk_size in range(max_num_chunks, 0, -1):
|
||||
_factors = sorted(self.get_factors(chunk_size))
|
||||
_asp_ratios = [(factor, chunk_size // factor) for factor in _factors]
|
||||
for height, width in _asp_ratios:
|
||||
ratio_float = height / width
|
||||
asp_dict[ratio_float].append((height, width))
|
||||
|
||||
# get the resolutions multiplied by the patch_size
|
||||
possible_resolutions = []
|
||||
for value in asp_dict.values():
|
||||
for height, depth in value:
|
||||
possible_resolutions.append((height * patch_size, depth * patch_size))
|
||||
|
||||
return possible_resolutions
|
||||
|
||||
@staticmethod
|
||||
def get_max_res_without_distortion(
|
||||
image_size: Tuple[int, int],
|
||||
target_size: Tuple[int, int],
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Determines the maximum resolution to which an image can be resized to without distorting its
|
||||
aspect ratio, based on the target resolution.
|
||||
|
||||
Args:
|
||||
image_size (Tuple[int, int]): The original resolution of the image (height, width).
|
||||
target_resolution (Tuple[int, int]): The desired resolution to fit the image into (height, width).
|
||||
Returns:
|
||||
Tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized.
|
||||
Example:
|
||||
>>> _get_max_res_without_distortion([200, 300], target_size = [450, 200])
|
||||
(134, 200)
|
||||
>>> _get_max_res_without_distortion([800, 600], target_size = [450, 1300])
|
||||
(450, 338)
|
||||
"""
|
||||
|
||||
original_width, original_height = image_size
|
||||
target_width, target_height = target_size
|
||||
|
||||
scale_w = target_width / original_width
|
||||
scale_h = target_height / original_height
|
||||
|
||||
if scale_w < scale_h:
|
||||
new_width = target_width
|
||||
new_height = min(math.floor(original_height * scale_w), target_height)
|
||||
else:
|
||||
new_height = target_height
|
||||
new_width = min(math.floor(original_width * scale_h), target_width)
|
||||
|
||||
return new_width, new_height
|
||||
|
||||
def _pad(self, image: Image.Image, target_size) -> Image.Image:
|
||||
new_width, new_height = target_size
|
||||
new_im = Image.new(mode="RGB", size=(new_width, new_height), color=(0, 0, 0)) # type: ignore
|
||||
new_im.paste(image)
|
||||
return new_im
|
||||
|
||||
def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
|
||||
# Split image into number of required tiles (width x height)
|
||||
num_channels, height, width = image.size()
|
||||
image = image.view(num_channels, nch, height // nch, ncw, width // ncw)
|
||||
# Permute dimensions to reorder the axes
|
||||
image = image.permute(1, 3, 0, 2, 4).contiguous()
|
||||
# Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
|
||||
image = image.view(ncw * nch, num_channels, height // nch, width // ncw)
|
||||
return image
|
||||
|
||||
def resize_without_distortion(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
target_size: Tuple[int, int],
|
||||
max_upscaling_size: Optional[int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Used to resize an image to target_resolution, without distortion.
|
||||
|
||||
If target_size requires upscaling the image, the user can set max_upscaling_size to
|
||||
limit the upscaling to a maximum size. In this case, since we rescale without distortion,
|
||||
modifying target_size works as a boundary for the image's largest side.
|
||||
|
||||
Args:
|
||||
resample (str): Resampling method used when resizing images.
|
||||
Supports "nearest", "nearest_exact", "bilinear", "bicubic".
|
||||
max_upscaling_size (int): The maximum size to upscale the image to.
|
||||
If None, there is no limit.
|
||||
Examples:
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 600
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(600, 300) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 600
|
||||
>>> image_size = (2000, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 100) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 2000
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 500) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = None
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 500) # new_size_without_distortion
|
||||
"""
|
||||
|
||||
image_width, image_height = image.size
|
||||
image_size = (image_width, image_height)
|
||||
|
||||
# If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size
|
||||
if max_upscaling_size is not None:
|
||||
new_target_width = min(max(image_width, max_upscaling_size), target_size[0])
|
||||
new_target_height = min(max(image_height, max_upscaling_size), target_size[1])
|
||||
target_size = (new_target_width, new_target_height)
|
||||
|
||||
# resize to target_size while preserving aspect ratio
|
||||
new_size_without_distortion = self.get_max_res_without_distortion(image_size, target_size)
|
||||
|
||||
image = F.resize(
|
||||
image,
|
||||
(new_size_without_distortion[1], new_size_without_distortion[0]),
|
||||
interpolation=self.resample,
|
||||
)
|
||||
|
||||
return image
|
||||
|
||||
def get_best_fit(
|
||||
self,
|
||||
image_size: Tuple[int, int],
|
||||
possible_resolutions: torch.Tensor,
|
||||
resize_to_max_canvas: bool = False,
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Determines the best canvas possible from a list of possible resolutions to, without distortion,
|
||||
resize an image to.
|
||||
|
||||
For each possible resolution, calculates the scaling factors for
|
||||
width and height, and selects the smallest one, which is the limiting side.
|
||||
E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
|
||||
therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.
|
||||
|
||||
If upscaling is possible (any of the scaling factors is greater than 1),
|
||||
then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True.
|
||||
|
||||
If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
|
||||
reduce downscaling as much as possible.
|
||||
|
||||
If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
|
||||
to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
|
||||
has more padding.
|
||||
|
||||
Args:
|
||||
image_size (Tuple[int, int]): A tuple containing the height and width of the image.
|
||||
possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
|
||||
row represents a possible resolution (height, width).
|
||||
use_max_upscaling (bool): If True, will return the largest upscaling resolution.
|
||||
|
||||
Returns:
|
||||
List[int]: The best resolution [height, width] for the given image.
|
||||
|
||||
Example:
|
||||
>>> image_size = (200, 300)
|
||||
>>> possible_resolutions = torch.tensor([[224, 672],
|
||||
... [672, 224],
|
||||
... [224, 448],
|
||||
... [448, 224],
|
||||
... [224, 224]])
|
||||
>>> _get_smallest_upscaling_possibility(image_size, possible_resolutions)
|
||||
[224, 448]
|
||||
|
||||
We have:
|
||||
scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
|
||||
scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
|
||||
scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
|
||||
Only one of the scales > 1:
|
||||
upscaling_possible = tensor([1.1200, 1.1200])
|
||||
smallest_rescale = tensor(1.1200)
|
||||
So we pick the resolution with the smallest smallest area:
|
||||
areas = tensor([150528, 100352]) # [672, 224], [224, 448]
|
||||
optimal_canvas = tensor([224, 448])
|
||||
"""
|
||||
|
||||
original_width, original_height = image_size
|
||||
|
||||
# get all possible resolutions heights/widths
|
||||
target_widths, target_heights = (
|
||||
possible_resolutions[:, 0],
|
||||
possible_resolutions[:, 1],
|
||||
)
|
||||
|
||||
# get scaling factors to resize the image without distortion
|
||||
scale_w = target_widths / original_width
|
||||
scale_h = target_heights / original_height
|
||||
|
||||
# get the min scale between width and height (limiting side -> no distortion)
|
||||
scales = torch.where(scale_w > scale_h, scale_h, scale_w)
|
||||
|
||||
# filter only scales that allow upscaling
|
||||
upscaling_options = scales[scales >= 1]
|
||||
if len(upscaling_options) > 0:
|
||||
if resize_to_max_canvas:
|
||||
selected_scale = torch.max(upscaling_options)
|
||||
else:
|
||||
selected_scale = torch.min(upscaling_options)
|
||||
else:
|
||||
# no upscaling possible,
|
||||
# get the minimum downscaling (max scale for scales<1)
|
||||
downscaling_options = scales[scales < 1]
|
||||
selected_scale = torch.max(downscaling_options)
|
||||
|
||||
# get all resolutions that support this scaling factor,
|
||||
# e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
|
||||
chosen_canvas = possible_resolutions[scales == selected_scale]
|
||||
|
||||
# if there are multiple resolutions,
|
||||
# get the one with minimum area to reduce padding
|
||||
if len(chosen_canvas) > 1:
|
||||
areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
|
||||
optimal_idx = torch.argmin(areas)
|
||||
optimal_canvas = chosen_canvas[optimal_idx]
|
||||
else:
|
||||
optimal_canvas = chosen_canvas[0]
|
||||
|
||||
return tuple(optimal_canvas.tolist())
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
image: Image.Image,
|
||||
max_num_chunks: int,
|
||||
normalize_img: bool = True,
|
||||
resize_to_max_canvas: bool = False,
|
||||
) -> Tuple[Any, Any]:
|
||||
"""
|
||||
Args:
|
||||
image (PIL.Image): Image to be resized.
|
||||
max_num_chunks (int): Maximum number of chunks to split the image into.
|
||||
normalize_img (bool): Whether to normalize the image.
|
||||
resize_to_max_canvas (bool): Whether to resize the image to the maximum canvas size.
|
||||
If True, picks the canvas the allows the largest resizing without distortion.
|
||||
If False, downsample as little as possible, including no resizing at all,
|
||||
but never upsample, unless the image is smaller than the patch size.
|
||||
"""
|
||||
assert max_num_chunks > 0
|
||||
assert isinstance(image, Image.Image), type(image)
|
||||
w, h = image.size
|
||||
|
||||
possible_resolutions = self.find_supported_resolutions(max_num_chunks=max_num_chunks, patch_size=self.size)
|
||||
possible_resolutions = torch.tensor(possible_resolutions)
|
||||
|
||||
best_resolution = self.get_best_fit(
|
||||
image_size=(w, h),
|
||||
possible_resolutions=possible_resolutions,
|
||||
resize_to_max_canvas=resize_to_max_canvas,
|
||||
)
|
||||
|
||||
max_upscaling_size = None if resize_to_max_canvas else self.size
|
||||
image = self.resize_without_distortion(image, best_resolution, max_upscaling_size)
|
||||
image = self._pad(image, best_resolution)
|
||||
|
||||
image = self.to_tensor(image)
|
||||
|
||||
if normalize_img:
|
||||
image = self.normalize(image)
|
||||
|
||||
ratio_w, ratio_h = (
|
||||
best_resolution[0] // self.size,
|
||||
best_resolution[1] // self.size,
|
||||
)
|
||||
|
||||
image = self._split(image, ratio_w, ratio_h) # type: ignore
|
||||
|
||||
ar = (ratio_h, ratio_w)
|
||||
return image, ar
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import collections
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_negative_inf_value(dtype):
|
||||
return torch.finfo(dtype).min
|
||||
|
||||
|
||||
def to_2tuple(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return x
|
||||
return (x, x)
|
||||
|
|
@ -9,18 +9,18 @@ from copy import deepcopy
|
|||
from functools import partial
|
||||
from typing import Any, Generator
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.models.llama.datatypes import Model
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from .common import model_checkpoint_dir
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .generation import Llama, model_checkpoint_dir
|
||||
from .llama3.generation import Llama3
|
||||
from .parallel_utils import ModelParallelProcessGroup
|
||||
|
||||
|
||||
|
|
@ -43,7 +43,7 @@ def init_model_cb(
|
|||
model_id: str,
|
||||
llama_model: Model,
|
||||
):
|
||||
llama = Llama.build(config, model_id, llama_model)
|
||||
llama = Llama3.build(config, model_id, llama_model)
|
||||
return ModelRunner(llama)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from .generation import TokenResult
|
||||
from .common import TokenResult
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -207,7 +207,7 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage
|
|||
return parse_message(maybe_json)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
except ValueError as e:
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
|
|
@ -352,7 +352,7 @@ class ModelParallelProcessGroup:
|
|||
if isinstance(obj, TaskResponse):
|
||||
yield obj.result
|
||||
|
||||
except GeneratorExit as e:
|
||||
except GeneratorExit:
|
||||
self.request_socket.send(encode_msg(CancelSentinel()))
|
||||
while True:
|
||||
obj_json = self.request_socket.send()
|
||||
|
|
|
|||
|
|
@ -7,6 +7,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
# The file gets a special treatment for now?
|
||||
# ruff: noqa: N803
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
|
|
|||
|
|
@ -15,8 +15,6 @@ import torch
|
|||
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
|
||||
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
||||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||
from llama_models.llama3.api.args import ModelArgs
|
||||
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
|
||||
from torch import Tensor, nn
|
||||
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
|
||||
|
||||
|
|
@ -24,6 +22,8 @@ from llama_stack.apis.inference import QuantizationType
|
|||
from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
|
||||
from ...llama3.args import ModelArgs
|
||||
from ...llama3.model import Transformer, TransformerBlock
|
||||
from ..config import MetaReferenceQuantizedInferenceConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
|
|||
|
|
@ -22,11 +22,11 @@ from fairscale.nn.model_parallel.initialize import (
|
|||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from llama_models.llama3.api.args import ModelArgs
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.inline.inference.meta_reference.llama3.args import ModelArgs
|
||||
from llama_stack.providers.inline.inference.meta_reference.llama3.model import Transformer, TransformerBlock
|
||||
from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impls import (
|
||||
quantize_fp8,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ NPROC=$7
|
|||
|
||||
echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR
|
||||
|
||||
NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" PYTHONPATH="/home/$USER/llama-models:/home/$USER/llama-stack" \
|
||||
NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" PYTHONPATH="/home/$USER/llama-stack" \
|
||||
torchrun \
|
||||
--nnodes=$NNODES --nproc_per_node=$NPROC \
|
||||
--rdzv_id=$RUN_ID \
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ class SentenceTransformersInferenceImpl(
|
|||
self,
|
||||
model_id: str,
|
||||
content: str,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
|
|
@ -64,7 +64,7 @@ class SentenceTransformersInferenceImpl(
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ import os
|
|||
import uuid
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.sampling_params import SamplingParams as VLLMSamplingParams
|
||||
|
|
@ -36,6 +35,7 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
|
|
@ -143,7 +143,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
|
|
@ -154,7 +154,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
|
|
@ -163,6 +163,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
assert self.engine is not None
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
|
|
|
|||
|
|
@ -10,16 +10,19 @@
|
|||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Any, Mapping
|
||||
|
||||
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
|
||||
|
||||
|
||||
def llama_stack_instruct_to_torchtune_instruct(sample: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
def llama_stack_instruct_to_torchtune_instruct(
|
||||
sample: Mapping[str, Any],
|
||||
) -> Mapping[str, Any]:
|
||||
assert ColumnName.chat_completion_input.value in sample and ColumnName.expected_answer.value in sample, (
|
||||
"Invalid input row"
|
||||
)
|
||||
input_messages = eval(str(sample[ColumnName.chat_completion_input.value]))
|
||||
input_messages = json.loads(sample[ColumnName.chat_completion_input.value])
|
||||
|
||||
assert len(input_messages) == 1, "llama stack intruct dataset format only supports 1 user message"
|
||||
input_message = input_messages[0]
|
||||
|
|
@ -37,7 +40,7 @@ def llama_stack_instruct_to_torchtune_instruct(sample: Mapping[str, Any]) -> Map
|
|||
def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
assert ColumnName.dialog.value in sample, "Invalid input row"
|
||||
role_map = {"user": "human", "assistant": "gpt"}
|
||||
dialog = eval(str(sample[ColumnName.dialog.value]))
|
||||
dialog = json.loads(sample[ColumnName.dialog.value])
|
||||
|
||||
assert len(dialog) > 1, "dialog must have at least 2 messagse"
|
||||
roles = []
|
||||
|
|
|
|||
|
|
@ -264,7 +264,7 @@ class LoraFinetuningSingleDevice:
|
|||
)
|
||||
|
||||
self.adapter_params = get_adapter_params(model)
|
||||
self._is_dora = any(["magnitude" in k for k in self.adapter_params.keys()])
|
||||
self._is_dora = any("magnitude" in k for k in self.adapter_params.keys())
|
||||
|
||||
set_trainable_params(model, self.adapter_params)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from llama_stack.apis.scoring_functions import (
|
|||
)
|
||||
|
||||
MULTILINGUAL_ANSWER_REGEXES = [
|
||||
r"The best answer is ",
|
||||
r"Answer\s*:",
|
||||
r"Answer\s*:", # Korean invisible character
|
||||
r"উত্তর\s*:",
|
||||
|
|
|
|||
|
|
@ -133,7 +133,7 @@ class BraintrustScoringImpl(
|
|||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def list_scoring_functions(self) -> List[ScoringFn]:
|
||||
scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
|
||||
scoring_fn_defs_list = list(self.supported_fn_defs_registry.values())
|
||||
for f in scoring_fn_defs_list:
|
||||
assert f.identifier.startswith("braintrust"), (
|
||||
"All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from llama_stack.providers.utils.common.data_schema_validator import (
|
|||
from .config import LlmAsJudgeScoringConfig
|
||||
from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
|
||||
|
||||
LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
|
||||
LLM_JUDGE_FN = LlmAsJudgeScoringFn
|
||||
|
||||
|
||||
class LlmAsJudgeScoringImpl(
|
||||
|
|
@ -43,23 +43,17 @@ class LlmAsJudgeScoringImpl(
|
|||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets_api
|
||||
self.inference_api = inference_api
|
||||
self.scoring_fn_id_impls = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
for fn in LLM_JUDGE_FNS:
|
||||
impl = fn(inference_api=self.inference_api)
|
||||
for fn_defs in impl.get_supported_scoring_fn_defs():
|
||||
self.scoring_fn_id_impls[fn_defs.identifier] = impl
|
||||
self.llm_as_judge_fn = impl
|
||||
impl = LLM_JUDGE_FN(inference_api=self.inference_api)
|
||||
self.llm_as_judge_fn = impl
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def list_scoring_functions(self) -> List[ScoringFn]:
|
||||
scoring_fn_defs_list = [
|
||||
fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
|
||||
]
|
||||
scoring_fn_defs_list = self.llm_as_judge_fn.get_supported_scoring_fn_defs()
|
||||
|
||||
for f in scoring_fn_defs_list:
|
||||
for f in self.llm_as_judge_fn.get_supported_scoring_fn_defs():
|
||||
assert f.identifier.startswith("llm-as-judge"), (
|
||||
"All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
|
||||
)
|
||||
|
|
@ -67,7 +61,7 @@ class LlmAsJudgeScoringImpl(
|
|||
return scoring_fn_defs_list
|
||||
|
||||
async def register_scoring_function(self, function_def: ScoringFn) -> None:
|
||||
raise NotImplementedError("Register scoring function not implemented yet")
|
||||
self.llm_as_judge_fn.register_scoring_fn_def(function_def)
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
|
|
@ -102,9 +96,7 @@ class LlmAsJudgeScoringImpl(
|
|||
) -> ScoreResponse:
|
||||
res = {}
|
||||
for scoring_fn_id in scoring_functions.keys():
|
||||
if scoring_fn_id not in self.scoring_fn_id_impls:
|
||||
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
|
||||
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
|
||||
scoring_fn = self.llm_as_judge_fn
|
||||
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
|
||||
score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
|
||||
agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.inference.inference import Inference
|
||||
from llama_stack.apis.inference.inference import Inference, UserMessage
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
|
@ -58,10 +58,9 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
|
|||
judge_response = await self.inference_api.chat_completion(
|
||||
model_id=fn_def.params.judge_model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": judge_input_msg,
|
||||
}
|
||||
UserMessage(
|
||||
content=judge_input_msg,
|
||||
),
|
||||
],
|
||||
)
|
||||
content = judge_response.completion_message.content
|
||||
|
|
|
|||
|
|
@ -44,9 +44,9 @@ class TelemetryConfig(BaseModel):
|
|||
return v
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str = "runtime", db_name: str = "trace_store.db") -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]:
|
||||
return {
|
||||
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
|
||||
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
|
||||
"sqlite_db_path": "${env.SQLITE_DB_PATH:~/.llama/" + __distro_dir__ + "/" + db_name + "}",
|
||||
"sqlite_db_path": "${env.SQLITE_DB_PATH:" + __distro_dir__ + "/" + db_name + "}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class SQLiteSpanProcessor(SpanProcessor):
|
|||
self._local.conn = sqlite3.connect(self.conn_string)
|
||||
except Exception as e:
|
||||
print(f"Error connecting to SQLite database: {e}")
|
||||
raise e
|
||||
raise
|
||||
return self._local.conn
|
||||
|
||||
def setup_database(self):
|
||||
|
|
|
|||
|
|
@ -73,6 +73,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = deps.get(Api.datasetio)
|
||||
self.meter = None
|
||||
|
||||
resource = Resource.create(
|
||||
{
|
||||
|
|
@ -171,6 +172,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
return _GLOBAL_STORAGE["gauges"][name]
|
||||
|
||||
def _log_metric(self, event: MetricEvent) -> None:
|
||||
if self.meter is None:
|
||||
return
|
||||
if isinstance(event.value, int):
|
||||
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||
counter.add(event.value, attributes=event.attributes)
|
||||
|
|
|
|||
|
|
@ -4,13 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .code_interpreter import CodeInterpreterToolRuntimeImpl
|
||||
from .config import CodeInterpreterToolConfig
|
||||
|
||||
__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]
|
||||
|
||||
|
||||
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps):
|
||||
from .code_interpreter import CodeInterpreterToolRuntimeImpl
|
||||
|
||||
impl = CodeInterpreterToolRuntimeImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class FaissVectorIOConfig(BaseModel):
|
|||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
|
|
|
|||
19
llama_stack/providers/inline/vector_io/milvus/__init__.py
Normal file
19
llama_stack/providers/inline/vector_io/milvus/__init__.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import MilvusVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
|
||||
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
20
llama_stack/providers/inline/vector_io/milvus/config.py
Normal file
20
llama_stack/providers/inline/vector_io/milvus/config.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MilvusVectorIOConfig(BaseModel):
|
||||
db_path: str
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {"db_path": "${env.MILVUS_DB_PATH}"}
|
||||
|
|
@ -15,5 +15,5 @@ class SQLiteVectorIOConfig(BaseModel):
|
|||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
|
||||
return {
|
||||
"db_path": "${env.SQLITE_STORE_DIR:~/.llama/" + __distro_dir__ + "}/" + "sqlite_vec.db",
|
||||
"db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + "sqlite_vec.db",
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue