Mirror of https://github.com/meta-llama/llama-stack.git, synced 2026-01-06 23:29:57 +00:00

Merge branch 'meta-llama:main' into main

Commit b52a265a51: 4 changed files with 63 additions and 51 deletions
.github/CODEOWNERS (vendored)

@@ -2,4 +2,4 @@
 # These owners will be the default owners for everything in
 # the repo. Unless a later match takes precedence,
-* @ashwinb @yanxi0830 @hardikjshah @dltn @raghotham @dineshyv @vladimirivic @sixianyi0721 @ehhuang @terrytangyuan
+* @ashwinb @yanxi0830 @hardikjshah @dltn @raghotham @dineshyv @vladimirivic @sixianyi0721 @ehhuang @terrytangyuan @SLR722

@@ -122,7 +122,7 @@ response = agent.create_turn(
         ],
         documents=[
             {
-                "content": "https://raw.githubusercontent.com/example/doc.rst",
+                "content": "https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/memory_optimizations.rst",
                 "mime_type": "text/plain",
             }
         ],

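For context, the snippet this hunk edits demonstrates attaching documents to an agent turn. A minimal sketch of the surrounding call, assuming the llama-stack client Agent API; the user message and session id below are illustrative and not part of the diff:

# Hedged sketch: `agent` is assumed to be a llama-stack Agent built elsewhere;
# only the documents=[...] block is taken verbatim from the hunk above.
response = agent.create_turn(
    messages=[
        # illustrative user message, not from the diff
        {"role": "user", "content": "How can I reduce memory use during fine-tuning?"},
    ],
    documents=[
        {
            "content": "https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/memory_optimizations.rst",
            "mime_type": "text/plain",
        }
    ],
    session_id=session_id,  # assumed created earlier via agent.create_session(...)
)
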
@@ -309,13 +309,14 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
        if provider_vector_db_id is None:
            provider_vector_db_id = vector_db_id
        if provider_id is None:
-            # If provider_id not specified, use the only provider if it supports this shield type
-            if len(self.impls_by_provider_id) == 1:
+            if len(self.impls_by_provider_id) > 0:
                provider_id = list(self.impls_by_provider_id.keys())[0]
-            else:
-                raise ValueError(
-                    "No provider specified and multiple providers available. Please specify a provider_id."
-                )
+                if len(self.impls_by_provider_id) > 1:
+                    logger.warning(
+                        f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
+                    )
+            else:
+                raise ValueError("No provider available. Please configure a vector_io provider.")
        model = await self.get_object_by_identifier("model", embedding_model)
        if model is None:
            raise ValueError(f"Model {embedding_model} not found")

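The effect of this hunk: registering a vector DB without an explicit provider_id no longer fails when several vector_io providers are configured; the routing table now picks the first one and warns, and only errors when none exist. A standalone sketch of the new selection rule, with a plain dict standing in for impls_by_provider_id (function name and provider keys are illustrative):

import logging

logger = logging.getLogger(__name__)

def pick_provider(impls_by_provider_id: dict, provider_id=None) -> str:
    # Mirrors the new fallback: honor an explicit choice, otherwise take the
    # first configured provider, warning when that pick is arbitrary.
    if provider_id is not None:
        return provider_id
    if len(impls_by_provider_id) > 0:
        provider_id = list(impls_by_provider_id.keys())[0]
        if len(impls_by_provider_id) > 1:
            logger.warning(
                f"No provider specified and multiple providers available. "
                f"Arbitrarily selected the first provider {provider_id}."
            )
        return provider_id
    raise ValueError("No provider available. Please configure a vector_io provider.")

# pick_provider({"faiss": object(), "chroma": object()}) -> "faiss", with a warning
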
@@ -16,10 +16,11 @@ from llama_stack.apis.agents import (
     AgentTurnResponseTurnCompletePayload,
     StepType,
 )
-from llama_stack.apis.common.content_types import URL
+from llama_stack.apis.common.content_types import URL, TextDelta
 from llama_stack.apis.inference import (
     ChatCompletionResponse,
     ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
     CompletionMessage,
     LogProbConfig,
@@ -27,12 +28,15 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
     UserMessage,
 )
 from llama_stack.apis.safety import RunShieldResponse
 from llama_stack.apis.tools import (
+    ListToolGroupsResponse,
+    ListToolsResponse,
     Tool,
     ToolDef,
     ToolGroup,
@@ -40,7 +44,7 @@ from llama_stack.apis.tools import (
     ToolInvocationResult,
 )
 from llama_stack.apis.vector_io import QueryChunksResponse
-from llama_stack.models.llama.datatypes import BuiltinTool
+from llama_stack.models.llama.datatypes import BuiltinTool, StopReason
 from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
     MEMORY_QUERY_TOOL,
 )

@@ -54,36 +58,37 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 class MockInferenceAPI:
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = None,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         async def stream_response():
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
-                    event_type="start",
-                    delta="",
+                    event_type=ChatCompletionResponseEventType.start,
+                    delta=TextDelta(text=""),
                 )
             )

             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
-                    event_type="progress",
-                    delta="AI is a fascinating field...",
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=TextDelta(text="AI is a fascinating field..."),
                 )
             )

             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
-                    event_type="complete",
-                    delta="",
-                    stop_reason="end_of_turn",
+                    event_type=ChatCompletionResponseEventType.complete,
+                    delta=TextDelta(text=""),
+                    stop_reason=StopReason.end_of_turn,
                 )
             )

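The mock now yields typed ChatCompletionResponseEventType events carrying TextDelta payloads instead of bare strings. A small consumption sketch, assuming the mock returns stream_response() when stream=True (as the original helper does) and an asyncio test context; the model id is a placeholder the mock ignores:

async def collect_mock_stream() -> str:
    # Drain the start/progress/complete chunks and stitch the deltas together.
    api = MockInferenceAPI()
    stream = await api.chat_completion(model_id="mock-model", messages=[], stream=True)
    return "".join([chunk.event.delta.text async for chunk in stream])

# asyncio.run(collect_mock_stream()) == "AI is a fascinating field..."
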
@@ -133,12 +138,13 @@ class MockToolGroupsAPI:
             provider_resource_id=toolgroup_id,
         )

-    async def list_tool_groups(self) -> List[ToolGroup]:
-        return []
+    async def list_tool_groups(self) -> ListToolGroupsResponse:
+        return ListToolGroupsResponse(data=[])

-    async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
-        if tool_group_id == MEMORY_TOOLGROUP:
-            return [
+    async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
+        if toolgroup_id == MEMORY_TOOLGROUP:
+            return ListToolsResponse(
+                data=[
                 Tool(
                     identifier=MEMORY_QUERY_TOOL,
                     provider_resource_id=MEMORY_QUERY_TOOL,
@@ -149,8 +155,10 @@ class MockToolGroupsAPI:
                     parameters=[],
                 )
             ]
-        if tool_group_id == CODE_INTERPRETER_TOOLGROUP:
-            return [
+            )
+        if toolgroup_id == CODE_INTERPRETER_TOOLGROUP:
+            return ListToolsResponse(
+                data=[
                 Tool(
                     identifier="code_interpreter",
                     provider_resource_id="code_interpreter",
@@ -161,7 +169,8 @@ class MockToolGroupsAPI:
                     parameters=[],
                 )
             ]
-        return []
+            )
+        return ListToolsResponse(data=[])

     async def get_tool(self, tool_name: str) -> Tool:
         return Tool(
@@ -174,7 +183,7 @@ class MockToolGroupsAPI:
             parameters=[],
         )

-    async def unregister_tool_group(self, tool_group_id: str) -> None:
+    async def unregister_tool_group(self, toolgroup_id: str) -> None:
         pass

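Because list_tool_groups and list_tools now return typed wrappers, callers unpack the .data field rather than iterating the return value directly. A hedged usage sketch against the mock (toolgroup constants as defined in the test file; the helper name is illustrative):

async def tool_identifiers(toolgroup_id) -> list:
    # Illustrative only: ListToolsResponse carries its Tool entries in .data.
    api = MockToolGroupsAPI()
    response = await api.list_tools(toolgroup_id=toolgroup_id)
    return [tool.identifier for tool in response.data]

# tool_identifiers(MEMORY_TOOLGROUP) -> [MEMORY_QUERY_TOOL];
# an unmatched id returns [] via ListToolsResponse(data=[]).
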
@@ -382,10 +391,11 @@ async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, ex
     chat_agent = await impl.get_agent(response.agent_id)

     tool_defs, _ = await chat_agent._get_tool_defs()
+    tool_defs_names = [t.tool_name for t in tool_defs]
     if expected_memory:
-        assert MEMORY_QUERY_TOOL in tool_defs
+        assert MEMORY_QUERY_TOOL in tool_defs_names
     if expected_code_interpreter:
-        assert BuiltinTool.code_interpreter in tool_defs
+        assert BuiltinTool.code_interpreter in tool_defs_names
     if expected_memory and expected_code_interpreter:
         # override the tools for turn
         new_tool_defs, _ = await chat_agent._get_tool_defs(
@@ -396,5 +406,6 @@ async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, ex
             )
         ]
     )
-    assert MEMORY_QUERY_TOOL in new_tool_defs
-    assert BuiltinTool.code_interpreter not in new_tool_defs
+    new_tool_defs_names = [t.tool_name for t in new_tool_defs]
+    assert MEMORY_QUERY_TOOL in new_tool_defs_names
+    assert BuiltinTool.code_interpreter not in new_tool_defs_names

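The assertions switch to extracted names because _get_tool_defs returns ToolDefinition objects, so a string or BuiltinTool membership test against the raw list always fails. A toy illustration with a stand-in dataclass (FakeToolDef and the tool name here are hypothetical, not the real ToolDefinition):

from dataclasses import dataclass

@dataclass
class FakeToolDef:
    tool_name: str  # stand-in for ToolDefinition.tool_name

tool_defs = [FakeToolDef(tool_name="memory_query")]
assert "memory_query" not in tool_defs                       # object list: name lookup fails
assert "memory_query" in [t.tool_name for t in tool_defs]    # extracted names: passes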