llama-stack/llama_stack/distribution/routers/routers.py
Dinesh Yeduguru a5c57cd381
agents to use tools api (#673)
# What does this PR do?

PR #639 introduced the notion of a Tools API and the ability to invoke tools
through the API just like any other resource. This PR changes the agents to
start using the Tools API to invoke tools. Major changes include:
1) Ability to specify tool groups with AgentConfig
2) The agent fetches the tool definitions for the specified tools and passes
them along to the model
3) Attachments are now named Documents, and their behavior is mostly
unchanged from the user's perspective
4) You can specify args to be injected into a tool call through the agent
config. This is especially useful for the memory tool, where you want the
tool to operate on a specific memory bank.
5) You can also register tool groups with args, which lets the agent inject
those args into tool calls as well (see the sketch after this list)
6) All tests have been migrated to use the new Tools API and fixtures,
including the client SDK tests
7) Telemetry just works with the Tools API because of our trace protocol
decorator
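
For example, an agent config along these lines exercises (1), (4), and (5). This is a minimal sketch against the client SDK; the `builtin::memory` group name, the `memory_bank_ids` arg, and the model id are illustrative, not prescribed by this PR:
```
from llama_stack_client.types.agent_create_params import AgentConfig

agent_config = AgentConfig(
    model="meta-llama/Llama-3.1-8B-Instruct",
    instructions="You are a helpful assistant.",
    # (1) tool groups are named directly on the agent config
    toolgroups=[
        "builtin::code_interpreter",
        # (4)/(5) args given here are injected into every call made to the
        # group's tools, e.g. pinning the memory tool to one memory bank
        {
            "name": "builtin::memory",
            "args": {"memory_bank_ids": ["my_bank"]},
        },
    ],
    enable_session_persistence=False,
)
```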


## Test Plan
```
pytest -s -v -k fireworks llama_stack/providers/tests/agents/test_agents.py  \
   --safety-shield=meta-llama/Llama-Guard-3-8B \
   --inference-model=meta-llama/Llama-3.1-8B-Instruct

pytest -s -v -k together  llama_stack/providers/tests/tools/test_tools.py \
   --safety-shield=meta-llama/Llama-Guard-3-8B \
   --inference-model=meta-llama/Llama-3.1-8B-Instruct

LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py
```
run.yaml:
https://gist.github.com/dineshyv/0365845ad325e1c2cab755788ccc5994

Notebook:
https://colab.research.google.com/drive/1ck7hXQxRl6UvT-ijNRZ-gMZxH1G3cN2d?usp=sharing
2025-01-08 19:01:00 -08:00

425 lines
12 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, AsyncGenerator, Dict, List, Optional

from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.eval import (
    AppEvalTaskConfig,
    Eval,
    EvalTaskConfig,
    EvaluateResponse,
    Job,
    JobStatus,
)
from llama_stack.apis.inference import (
    EmbeddingsResponse,
    Inference,
    LogProbConfig,
    Message,
    ResponseFormat,
    SamplingParams,
    ToolChoice,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
from llama_stack.apis.memory_banks.memory_banks import BankParams
from llama_stack.apis.models import ModelType
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.scoring import (
    ScoreBatchResponse,
    ScoreResponse,
    Scoring,
    ScoringFnParams,
)
from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import ToolDef, ToolRuntime
from llama_stack.providers.datatypes import RoutingTable


class MemoryRouter(Memory):
    """Routes to a provider based on the memory bank identifier"""

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        self.routing_table = routing_table

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def register_memory_bank(
        self,
        memory_bank_id: str,
        params: BankParams,
        provider_id: Optional[str] = None,
        provider_memorybank_id: Optional[str] = None,
    ) -> None:
        await self.routing_table.register_memory_bank(
            memory_bank_id,
            params,
            provider_id,
            provider_memorybank_id,
        )

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        return await self.routing_table.get_provider_impl(bank_id).insert_documents(
            bank_id, documents, ttl_seconds
        )

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedContent,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        return await self.routing_table.get_provider_impl(bank_id).query_documents(
            bank_id, query, params
        )


class InferenceRouter(Inference):
    """Routes to a provider based on the model"""

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        self.routing_table = routing_table

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def register_model(
        self,
        model_id: str,
        provider_model_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        model_type: Optional[ModelType] = None,
    ) -> None:
        await self.routing_table.register_model(
            model_id, provider_model_id, provider_id, metadata, model_type
        )

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        model = await self.routing_table.get_model(model_id)
        if model is None:
            raise ValueError(f"Model '{model_id}' not found")
        if model.model_type == ModelType.embedding:
            raise ValueError(
                f"Model '{model_id}' is an embedding model and does not support chat completions"
            )
        params = dict(
            model_id=model_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        provider = self.routing_table.get_provider_impl(model_id)
        if stream:
            # the provider call returns a coroutine yielding an async
            # generator; await it here and re-expose it as a generator
            return (chunk async for chunk in await provider.chat_completion(**params))
        else:
            return await provider.chat_completion(**params)

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        model = await self.routing_table.get_model(model_id)
        if model is None:
            raise ValueError(f"Model '{model_id}' not found")
        if model.model_type == ModelType.embedding:
            raise ValueError(
                f"Model '{model_id}' is an embedding model and does not support completions"
            )
        provider = self.routing_table.get_provider_impl(model_id)
        params = dict(
            model_id=model_id,
            content=content,
            sampling_params=sampling_params,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return (chunk async for chunk in await provider.completion(**params))
        else:
            return await provider.completion(**params)

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse:
        model = await self.routing_table.get_model(model_id)
        if model is None:
            raise ValueError(f"Model '{model_id}' not found")
        if model.model_type == ModelType.llm:
            raise ValueError(
                f"Model '{model_id}' is an LLM and does not support embeddings"
            )
        return await self.routing_table.get_provider_impl(model_id).embeddings(
            model_id=model_id,
            contents=contents,
        )


class SafetyRouter(Safety):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        self.routing_table = routing_table

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> Shield:
        return await self.routing_table.register_shield(
            shield_id, provider_shield_id, provider_id, params
        )

    async def run_shield(
        self,
        shield_id: str,
        messages: List[Message],
        params: Optional[Dict[str, Any]] = None,
    ) -> RunShieldResponse:
        return await self.routing_table.get_provider_impl(shield_id).run_shield(
            shield_id=shield_id,
            messages=messages,
            params=params,
        )


class DatasetIORouter(DatasetIO):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        self.routing_table = routing_table

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def get_rows_paginated(
        self,
        dataset_id: str,
        rows_in_page: int,
        page_token: Optional[str] = None,
        filter_condition: Optional[str] = None,
    ) -> PaginatedRowsResult:
        return await self.routing_table.get_provider_impl(
            dataset_id
        ).get_rows_paginated(
            dataset_id=dataset_id,
            rows_in_page=rows_in_page,
            page_token=page_token,
            filter_condition=filter_condition,
        )

    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
        return await self.routing_table.get_provider_impl(dataset_id).append_rows(
            dataset_id=dataset_id,
            rows=rows,
        )


class ScoringRouter(Scoring):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        self.routing_table = routing_table

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def score_batch(
        self,
        dataset_id: str,
        scoring_functions: Optional[Dict[str, Optional[ScoringFnParams]]] = None,
        save_results_dataset: bool = False,
    ) -> ScoreBatchResponse:
        res = {}
        # score with each function against its own provider impl, then merge
        for fn_identifier in scoring_functions.keys():
            score_response = await self.routing_table.get_provider_impl(
                fn_identifier
            ).score_batch(
                dataset_id=dataset_id,
                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
            )
            res.update(score_response.results)
        if save_results_dataset:
            raise NotImplementedError("Save results dataset not implemented yet")
        return ScoreBatchResponse(
            results=res,
        )

    async def score(
        self,
        input_rows: List[Dict[str, Any]],
        scoring_functions: Optional[Dict[str, Optional[ScoringFnParams]]] = None,
    ) -> ScoreResponse:
        res = {}
        # look up and map each scoring function to its provider impl
        for fn_identifier in scoring_functions.keys():
            score_response = await self.routing_table.get_provider_impl(
                fn_identifier
            ).score(
                input_rows=input_rows,
                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
            )
            res.update(score_response.results)
        return ScoreResponse(results=res)


class EvalRouter(Eval):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        self.routing_table = routing_table

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def run_eval(
        self,
        task_id: str,
        task_config: AppEvalTaskConfig,
    ) -> Job:
        return await self.routing_table.get_provider_impl(task_id).run_eval(
            task_id=task_id,
            task_config=task_config,
        )

    async def evaluate_rows(
        self,
        task_id: str,
        input_rows: List[Dict[str, Any]],
        scoring_functions: List[str],
        task_config: EvalTaskConfig,
    ) -> EvaluateResponse:
        return await self.routing_table.get_provider_impl(task_id).evaluate_rows(
            task_id=task_id,
            input_rows=input_rows,
            scoring_functions=scoring_functions,
            task_config=task_config,
        )

    async def job_status(
        self,
        task_id: str,
        job_id: str,
    ) -> Optional[JobStatus]:
        return await self.routing_table.get_provider_impl(task_id).job_status(
            task_id, job_id
        )

    async def job_cancel(
        self,
        task_id: str,
        job_id: str,
    ) -> None:
        await self.routing_table.get_provider_impl(task_id).job_cancel(
            task_id,
            job_id,
        )

    async def job_result(
        self,
        task_id: str,
        job_id: str,
    ) -> EvaluateResponse:
        return await self.routing_table.get_provider_impl(task_id).job_result(
            task_id,
            job_id,
        )


class ToolRuntimeRouter(ToolRuntime):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        self.routing_table = routing_table

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def invoke_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
        return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
            tool_name=tool_name,
            args=args,
        )

    async def list_runtime_tools(
        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
    ) -> List[ToolDef]:
        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
            tool_group_id, mcp_endpoint
        )
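

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the routing table and the "web_search"
# tool name are hypothetical): every router above follows the same pattern --
# resolve the provider implementation for a resource identifier via the
# routing table, then forward the call to it.
#
#   router = ToolRuntimeRouter(routing_table)
#   await router.initialize()
#   result = await router.invoke_tool(
#       tool_name="web_search",
#       args={"query": "llama stack"},
#   )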