Merge branch 'meta-llama:main' into main

This commit is contained in:
Francisco Arceo 2025-03-05 08:23:47 -05:00 committed by GitHub
commit b52a265a51
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 63 additions and 51 deletions

2
.github/CODEOWNERS vendored
View file

@ -2,4 +2,4 @@
# These owners will be the default owners for everything in # These owners will be the default owners for everything in
# the repo. Unless a later match takes precedence, # the repo. Unless a later match takes precedence,
* @ashwinb @yanxi0830 @hardikjshah @dltn @raghotham @dineshyv @vladimirivic @sixianyi0721 @ehhuang @terrytangyuan * @ashwinb @yanxi0830 @hardikjshah @dltn @raghotham @dineshyv @vladimirivic @sixianyi0721 @ehhuang @terrytangyuan @SLR722

View file

@ -122,7 +122,7 @@ response = agent.create_turn(
], ],
documents=[ documents=[
{ {
"content": "https://raw.githubusercontent.com/example/doc.rst", "content": "https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/memory_optimizations.rst",
"mime_type": "text/plain", "mime_type": "text/plain",
} }
], ],

View file

@ -309,13 +309,14 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
if provider_vector_db_id is None: if provider_vector_db_id is None:
provider_vector_db_id = vector_db_id provider_vector_db_id = vector_db_id
if provider_id is None: if provider_id is None:
# If provider_id not specified, use the only provider if it supports this shield type if len(self.impls_by_provider_id) > 0:
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0] provider_id = list(self.impls_by_provider_id.keys())[0]
else: if len(self.impls_by_provider_id) > 1:
raise ValueError( logger.warning(
"No provider specified and multiple providers available. Please specify a provider_id." f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
) )
else:
raise ValueError("No provider available. Please configure a vector_io provider.")
model = await self.get_object_by_identifier("model", embedding_model) model = await self.get_object_by_identifier("model", embedding_model)
if model is None: if model is None:
raise ValueError(f"Model {embedding_model} not found") raise ValueError(f"Model {embedding_model} not found")

View file

@ -16,10 +16,11 @@ from llama_stack.apis.agents import (
AgentTurnResponseTurnCompletePayload, AgentTurnResponseTurnCompletePayload,
StepType, StepType,
) )
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL, TextDelta
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionResponse, ChatCompletionResponse,
ChatCompletionResponseEvent, ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk, ChatCompletionResponseStreamChunk,
CompletionMessage, CompletionMessage,
LogProbConfig, LogProbConfig,
@ -27,12 +28,15 @@ from llama_stack.apis.inference import (
ResponseFormat, ResponseFormat,
SamplingParams, SamplingParams,
ToolChoice, ToolChoice,
ToolConfig,
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
UserMessage, UserMessage,
) )
from llama_stack.apis.safety import RunShieldResponse from llama_stack.apis.safety import RunShieldResponse
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
ListToolGroupsResponse,
ListToolsResponse,
Tool, Tool,
ToolDef, ToolDef,
ToolGroup, ToolGroup,
@ -40,7 +44,7 @@ from llama_stack.apis.tools import (
ToolInvocationResult, ToolInvocationResult,
) )
from llama_stack.apis.vector_io import QueryChunksResponse from llama_stack.apis.vector_io import QueryChunksResponse
from llama_stack.models.llama.datatypes import BuiltinTool from llama_stack.models.llama.datatypes import BuiltinTool, StopReason
from llama_stack.providers.inline.agents.meta_reference.agent_instance import ( from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
MEMORY_QUERY_TOOL, MEMORY_QUERY_TOOL,
) )
@ -54,36 +58,37 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class MockInferenceAPI: class MockInferenceAPI:
async def chat_completion( async def chat_completion(
self, self,
model: str, model_id: str,
messages: List[Message], messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(), sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None, tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = None, tool_choice: Optional[ToolChoice] = None,
tool_prompt_format: Optional[ToolPromptFormat] = None, tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
async def stream_response(): async def stream_response():
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type="start", event_type=ChatCompletionResponseEventType.start,
delta="", delta=TextDelta(text=""),
) )
) )
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type="progress", event_type=ChatCompletionResponseEventType.progress,
delta="AI is a fascinating field...", delta=TextDelta(text="AI is a fascinating field..."),
) )
) )
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type="complete", event_type=ChatCompletionResponseEventType.complete,
delta="", delta=TextDelta(text=""),
stop_reason="end_of_turn", stop_reason=StopReason.end_of_turn,
) )
) )
@ -133,12 +138,13 @@ class MockToolGroupsAPI:
provider_resource_id=toolgroup_id, provider_resource_id=toolgroup_id,
) )
async def list_tool_groups(self) -> List[ToolGroup]: async def list_tool_groups(self) -> ListToolGroupsResponse:
return [] return ListToolGroupsResponse(data=[])
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]: async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
if tool_group_id == MEMORY_TOOLGROUP: if toolgroup_id == MEMORY_TOOLGROUP:
return [ return ListToolsResponse(
data=[
Tool( Tool(
identifier=MEMORY_QUERY_TOOL, identifier=MEMORY_QUERY_TOOL,
provider_resource_id=MEMORY_QUERY_TOOL, provider_resource_id=MEMORY_QUERY_TOOL,
@ -149,8 +155,10 @@ class MockToolGroupsAPI:
parameters=[], parameters=[],
) )
] ]
if tool_group_id == CODE_INTERPRETER_TOOLGROUP: )
return [ if toolgroup_id == CODE_INTERPRETER_TOOLGROUP:
return ListToolsResponse(
data=[
Tool( Tool(
identifier="code_interpreter", identifier="code_interpreter",
provider_resource_id="code_interpreter", provider_resource_id="code_interpreter",
@ -161,7 +169,8 @@ class MockToolGroupsAPI:
parameters=[], parameters=[],
) )
] ]
return [] )
return ListToolsResponse(data=[])
async def get_tool(self, tool_name: str) -> Tool: async def get_tool(self, tool_name: str) -> Tool:
return Tool( return Tool(
@ -174,7 +183,7 @@ class MockToolGroupsAPI:
parameters=[], parameters=[],
) )
async def unregister_tool_group(self, tool_group_id: str) -> None: async def unregister_tool_group(self, toolgroup_id: str) -> None:
pass pass
@ -382,10 +391,11 @@ async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, ex
chat_agent = await impl.get_agent(response.agent_id) chat_agent = await impl.get_agent(response.agent_id)
tool_defs, _ = await chat_agent._get_tool_defs() tool_defs, _ = await chat_agent._get_tool_defs()
tool_defs_names = [t.tool_name for t in tool_defs]
if expected_memory: if expected_memory:
assert MEMORY_QUERY_TOOL in tool_defs assert MEMORY_QUERY_TOOL in tool_defs_names
if expected_code_interpreter: if expected_code_interpreter:
assert BuiltinTool.code_interpreter in tool_defs assert BuiltinTool.code_interpreter in tool_defs_names
if expected_memory and expected_code_interpreter: if expected_memory and expected_code_interpreter:
# override the tools for turn # override the tools for turn
new_tool_defs, _ = await chat_agent._get_tool_defs( new_tool_defs, _ = await chat_agent._get_tool_defs(
@ -396,5 +406,6 @@ async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, ex
) )
] ]
) )
assert MEMORY_QUERY_TOOL in new_tool_defs new_tool_defs_names = [t.tool_name for t in new_tool_defs]
assert BuiltinTool.code_interpreter not in new_tool_defs assert MEMORY_QUERY_TOOL in new_tool_defs_names
assert BuiltinTool.code_interpreter not in new_tool_defs_names