mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-06 12:37:33 +00:00
Codemod from llama_toolchain -> llama_stack
- added providers/registry - cleaned up api/ subdirectories and moved impls away - restructured api/api.py - from llama_stack.apis.<api> import foo should work now - update imports to do llama_stack.apis.<api> - update many other imports - added __init__, fixed some registry imports - updated registry imports - create_agentic_system -> create_agent - AgenticSystem -> Agent
This commit is contained in:
parent
2cf731faea
commit
76b354a081
128 changed files with 381 additions and 376 deletions
7
llama_stack/apis/agents/__init__.py
Normal file
7
llama_stack/apis/agents/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .agents import * # noqa: F401 F403
|
457
llama_stack/apis/agents/agents.py
Normal file
457
llama_stack/apis/agents/agents.py
Normal file
|
@ -0,0 +1,457 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.common.deployment_types import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Attachment(BaseModel):
|
||||
content: InterleavedTextMedia | URL
|
||||
mime_type: str
|
||||
|
||||
|
||||
class AgentTool(Enum):
|
||||
brave_search = "brave_search"
|
||||
wolfram_alpha = "wolfram_alpha"
|
||||
photogen = "photogen"
|
||||
code_interpreter = "code_interpreter"
|
||||
|
||||
function_call = "function_call"
|
||||
memory = "memory"
|
||||
|
||||
|
||||
class ToolDefinitionCommon(BaseModel):
|
||||
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class SearchEngineType(Enum):
|
||||
bing = "bing"
|
||||
brave = "brave"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SearchToolDefinition(ToolDefinitionCommon):
|
||||
# NOTE: brave_search is just a placeholder since model always uses
|
||||
# brave_search as tool call name
|
||||
type: Literal[AgentTool.brave_search.value] = AgentTool.brave_search.value
|
||||
engine: SearchEngineType = SearchEngineType.brave
|
||||
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class WolframAlphaToolDefinition(ToolDefinitionCommon):
|
||||
type: Literal[AgentTool.wolfram_alpha.value] = AgentTool.wolfram_alpha.value
|
||||
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PhotogenToolDefinition(ToolDefinitionCommon):
|
||||
type: Literal[AgentTool.photogen.value] = AgentTool.photogen.value
|
||||
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CodeInterpreterToolDefinition(ToolDefinitionCommon):
|
||||
type: Literal[AgentTool.code_interpreter.value] = AgentTool.code_interpreter.value
|
||||
enable_inline_code_execution: bool = True
|
||||
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FunctionCallToolDefinition(ToolDefinitionCommon):
|
||||
type: Literal[AgentTool.function_call.value] = AgentTool.function_call.value
|
||||
function_name: str
|
||||
description: str
|
||||
parameters: Dict[str, ToolParamDefinition]
|
||||
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||
|
||||
|
||||
class _MemoryBankConfigCommon(BaseModel):
|
||||
bank_id: str
|
||||
|
||||
|
||||
class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
||||
|
||||
|
||||
class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
|
||||
keys: List[str] # what keys to focus on
|
||||
|
||||
|
||||
class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
|
||||
|
||||
|
||||
class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
|
||||
entities: List[str] # what entities to focus on
|
||||
|
||||
|
||||
MemoryBankConfig = Annotated[
|
||||
Union[
|
||||
AgentVectorMemoryBankConfig,
|
||||
AgentKeyValueMemoryBankConfig,
|
||||
AgentKeywordMemoryBankConfig,
|
||||
AgentGraphMemoryBankConfig,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class MemoryQueryGenerator(Enum):
|
||||
default = "default"
|
||||
llm = "llm"
|
||||
custom = "custom"
|
||||
|
||||
|
||||
class DefaultMemoryQueryGeneratorConfig(BaseModel):
|
||||
type: Literal[MemoryQueryGenerator.default.value] = (
|
||||
MemoryQueryGenerator.default.value
|
||||
)
|
||||
sep: str = " "
|
||||
|
||||
|
||||
class LLMMemoryQueryGeneratorConfig(BaseModel):
|
||||
type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
|
||||
model: str
|
||||
template: str
|
||||
|
||||
|
||||
class CustomMemoryQueryGeneratorConfig(BaseModel):
|
||||
type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
|
||||
|
||||
|
||||
MemoryQueryGeneratorConfig = Annotated[
|
||||
Union[
|
||||
DefaultMemoryQueryGeneratorConfig,
|
||||
LLMMemoryQueryGeneratorConfig,
|
||||
CustomMemoryQueryGeneratorConfig,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class MemoryToolDefinition(ToolDefinitionCommon):
|
||||
type: Literal[AgentTool.memory.value] = AgentTool.memory.value
|
||||
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
|
||||
# This config defines how a query is generated using the messages
|
||||
# for memory bank retrieval.
|
||||
query_generator_config: MemoryQueryGeneratorConfig = Field(
|
||||
default=DefaultMemoryQueryGeneratorConfig()
|
||||
)
|
||||
max_tokens_in_context: int = 4096
|
||||
max_chunks: int = 10
|
||||
|
||||
|
||||
AgentToolDefinition = Annotated[
|
||||
Union[
|
||||
SearchToolDefinition,
|
||||
WolframAlphaToolDefinition,
|
||||
PhotogenToolDefinition,
|
||||
CodeInterpreterToolDefinition,
|
||||
FunctionCallToolDefinition,
|
||||
MemoryToolDefinition,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class StepCommon(BaseModel):
|
||||
turn_id: str
|
||||
step_id: str
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class StepType(Enum):
|
||||
inference = "inference"
|
||||
tool_execution = "tool_execution"
|
||||
shield_call = "shield_call"
|
||||
memory_retrieval = "memory_retrieval"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InferenceStep(StepCommon):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
step_type: Literal[StepType.inference.value] = StepType.inference.value
|
||||
model_response: CompletionMessage
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolExecutionStep(StepCommon):
|
||||
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
|
||||
tool_calls: List[ToolCall]
|
||||
tool_responses: List[ToolResponse]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ShieldCallStep(StepCommon):
|
||||
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
|
||||
response: ShieldResponse
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryRetrievalStep(StepCommon):
|
||||
step_type: Literal[StepType.memory_retrieval.value] = (
|
||||
StepType.memory_retrieval.value
|
||||
)
|
||||
memory_bank_ids: List[str]
|
||||
inserted_context: InterleavedTextMedia
|
||||
|
||||
|
||||
Step = Annotated[
|
||||
Union[
|
||||
InferenceStep,
|
||||
ToolExecutionStep,
|
||||
ShieldCallStep,
|
||||
MemoryRetrievalStep,
|
||||
],
|
||||
Field(discriminator="step_type"),
|
||||
]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Turn(BaseModel):
|
||||
"""A single turn in an interaction with an Agentic System."""
|
||||
|
||||
turn_id: str
|
||||
session_id: str
|
||||
input_messages: List[
|
||||
Union[
|
||||
UserMessage,
|
||||
ToolResponseMessage,
|
||||
]
|
||||
]
|
||||
steps: List[Step]
|
||||
output_message: CompletionMessage
|
||||
output_attachments: List[Attachment] = Field(default_factory=list)
|
||||
|
||||
started_at: datetime
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Session(BaseModel):
|
||||
"""A single session of an interaction with an Agentic System."""
|
||||
|
||||
session_id: str
|
||||
session_name: str
|
||||
turns: List[Turn]
|
||||
started_at: datetime
|
||||
|
||||
memory_bank: Optional[MemoryBank] = None
|
||||
|
||||
|
||||
class AgentConfigCommon(BaseModel):
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||
|
||||
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||
|
||||
tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
|
||||
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
default=ToolPromptFormat.json
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentConfig(AgentConfigCommon):
|
||||
model: str
|
||||
instructions: str
|
||||
|
||||
|
||||
class AgentConfigOverridablePerTurn(AgentConfigCommon):
|
||||
instructions: Optional[str] = None
|
||||
|
||||
|
||||
class AgentTurnResponseEventType(Enum):
|
||||
step_start = "step_start"
|
||||
step_complete = "step_complete"
|
||||
step_progress = "step_progress"
|
||||
|
||||
turn_start = "turn_start"
|
||||
turn_complete = "turn_complete"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStepStartPayload(BaseModel):
|
||||
event_type: Literal[AgentTurnResponseEventType.step_start.value] = (
|
||||
AgentTurnResponseEventType.step_start.value
|
||||
)
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStepCompletePayload(BaseModel):
|
||||
event_type: Literal[AgentTurnResponseEventType.step_complete.value] = (
|
||||
AgentTurnResponseEventType.step_complete.value
|
||||
)
|
||||
step_type: StepType
|
||||
step_details: Step
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStepProgressPayload(BaseModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.step_progress.value] = (
|
||||
AgentTurnResponseEventType.step_progress.value
|
||||
)
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
|
||||
model_response_text_delta: Optional[str] = None
|
||||
tool_call_delta: Optional[ToolCallDelta] = None
|
||||
tool_response_text_delta: Optional[str] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseTurnStartPayload(BaseModel):
|
||||
event_type: Literal[AgentTurnResponseEventType.turn_start.value] = (
|
||||
AgentTurnResponseEventType.turn_start.value
|
||||
)
|
||||
turn_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseTurnCompletePayload(BaseModel):
|
||||
event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = (
|
||||
AgentTurnResponseEventType.turn_complete.value
|
||||
)
|
||||
turn: Turn
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseEvent(BaseModel):
|
||||
"""Streamed agent execution response."""
|
||||
|
||||
payload: Annotated[
|
||||
Union[
|
||||
AgentTurnResponseStepStartPayload,
|
||||
AgentTurnResponseStepProgressPayload,
|
||||
AgentTurnResponseStepCompletePayload,
|
||||
AgentTurnResponseTurnStartPayload,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
],
|
||||
Field(discriminator="event_type"),
|
||||
]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentCreateResponse(BaseModel):
|
||||
agent_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentSessionCreateResponse(BaseModel):
|
||||
session_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
|
||||
agent_id: str
|
||||
session_id: str
|
||||
|
||||
# TODO: figure out how we can simplify this and make why
|
||||
# ToolResponseMessage needs to be here (it is function call
|
||||
# execution from outside the system)
|
||||
messages: List[
|
||||
Union[
|
||||
UserMessage,
|
||||
ToolResponseMessage,
|
||||
]
|
||||
]
|
||||
attachments: Optional[List[Attachment]] = None
|
||||
|
||||
stream: Optional[bool] = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStreamChunk(BaseModel):
|
||||
event: AgentTurnResponseEvent
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentStepResponse(BaseModel):
|
||||
step: Step
|
||||
|
||||
|
||||
class Agents(Protocol):
|
||||
@webmethod(route="/agents/create")
|
||||
async def create_agent(
|
||||
self,
|
||||
agent_config: AgentConfig,
|
||||
) -> AgentCreateResponse: ...
|
||||
|
||||
@webmethod(route="/agents/turn/create")
|
||||
async def create_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
messages: List[
|
||||
Union[
|
||||
UserMessage,
|
||||
ToolResponseMessage,
|
||||
]
|
||||
],
|
||||
attachments: Optional[List[Attachment]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
) -> AgentTurnResponseStreamChunk: ...
|
||||
|
||||
@webmethod(route="/agents/turn/get")
|
||||
async def get_agents_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
turn_id: str,
|
||||
) -> Turn: ...
|
||||
|
||||
@webmethod(route="/agents/step/get")
|
||||
async def get_agents_step(
|
||||
self, agent_id: str, turn_id: str, step_id: str
|
||||
) -> AgentStepResponse: ...
|
||||
|
||||
@webmethod(route="/agents/session/create")
|
||||
async def create_agent_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_name: str,
|
||||
) -> AgentSessionCreateResponse: ...
|
||||
|
||||
@webmethod(route="/agents/session/get")
|
||||
async def get_agents_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_ids: Optional[List[str]] = None,
|
||||
) -> Session: ...
|
||||
|
||||
@webmethod(route="/agents/session/delete")
|
||||
async def delete_agents_session(self, agent_id: str, session_id: str) -> None: ...
|
||||
|
||||
@webmethod(route="/agents/delete")
|
||||
async def delete_agents(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> None: ...
|
210
llama_stack/apis/agents/client.py
Normal file
210
llama_stack/apis/agents/client.py
Normal file
|
@ -0,0 +1,210 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import fire
|
||||
|
||||
import httpx
|
||||
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.core.datatypes import RemoteProviderConfig
|
||||
|
||||
from .agents import * # noqa: F403
|
||||
from .event_logger import EventLogger
|
||||
|
||||
|
||||
async def get_client_impl(config: RemoteProviderConfig, _deps):
|
||||
return AgentsClient(config.url)
|
||||
|
||||
|
||||
def encodable_dict(d: BaseModel):
|
||||
return json.loads(d.json())
|
||||
|
||||
|
||||
class AgentsClient(Agents):
|
||||
def __init__(self, base_url: str):
|
||||
self.base_url = base_url
|
||||
|
||||
async def create_agent(self, agent_config: AgentConfig) -> AgentCreateResponse:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/agents/create",
|
||||
json={
|
||||
"agent_config": encodable_dict(agent_config),
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return AgentCreateResponse(**response.json())
|
||||
|
||||
async def create_agent_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_name: str,
|
||||
) -> AgentSessionCreateResponse:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/agents/session/create",
|
||||
json={
|
||||
"agent_id": agent_id,
|
||||
"session_name": session_name,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return AgentSessionCreateResponse(**response.json())
|
||||
|
||||
async def create_agent_turn(
|
||||
self,
|
||||
request: AgentTurnCreateRequest,
|
||||
) -> AsyncGenerator:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{self.base_url}/agents/turn/create",
|
||||
json=encodable_dict(request),
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
) as response:
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data:"):
|
||||
data = line[len("data: ") :]
|
||||
try:
|
||||
jdata = json.loads(data)
|
||||
if "error" in jdata:
|
||||
cprint(data, "red")
|
||||
continue
|
||||
|
||||
yield AgentTurnResponseStreamChunk(**jdata)
|
||||
except Exception as e:
|
||||
print(data)
|
||||
print(f"Error with parsing or validation: {e}")
|
||||
|
||||
|
||||
async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
|
||||
agent_config = AgentConfig(
|
||||
model="Meta-Llama3.1-8B-Instruct",
|
||||
instructions="You are a helpful assistant",
|
||||
sampling_params=SamplingParams(temperature=1.0, top_p=0.9),
|
||||
tools=tool_definitions,
|
||||
tool_choice=ToolChoice.auto,
|
||||
tool_prompt_format=ToolPromptFormat.function_tag,
|
||||
)
|
||||
|
||||
create_response = await api.create_agent(agent_config)
|
||||
session_response = await api.create_agent_session(
|
||||
agent_id=create_response.agent_id,
|
||||
session_name="test_session",
|
||||
)
|
||||
|
||||
for content in user_prompts:
|
||||
cprint(f"User> {content}", color="white", attrs=["bold"])
|
||||
iterator = api.create_agent_turn(
|
||||
AgentTurnCreateRequest(
|
||||
agent_id=create_response.agent_id,
|
||||
session_id=session_response.session_id,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
attachments=attachments,
|
||||
stream=True,
|
||||
)
|
||||
)
|
||||
|
||||
async for event, log in EventLogger().log(iterator):
|
||||
if log is not None:
|
||||
log.print()
|
||||
|
||||
|
||||
async def run_main(host: str, port: int):
|
||||
api = AgentsClient(f"http://{host}:{port}")
|
||||
|
||||
tool_definitions = [
|
||||
SearchToolDefinition(engine=SearchEngineType.bing),
|
||||
WolframAlphaToolDefinition(),
|
||||
CodeInterpreterToolDefinition(),
|
||||
]
|
||||
tool_definitions += [
|
||||
FunctionCallToolDefinition(
|
||||
function_name="get_boiling_point",
|
||||
description="Get the boiling point of a imaginary liquids (eg. polyjuice)",
|
||||
parameters={
|
||||
"liquid_name": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="The name of the liquid",
|
||||
required=True,
|
||||
),
|
||||
"celcius": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="Whether to return the boiling point in Celcius",
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
user_prompts = [
|
||||
"Who are you?",
|
||||
"what is the 100th prime number?",
|
||||
"Search web for who was 44th President of USA?",
|
||||
"Write code to check if a number is prime. Use that to check if 7 is prime",
|
||||
"What is the boiling point of polyjuicepotion ?",
|
||||
]
|
||||
await _run_agent(api, tool_definitions, user_prompts)
|
||||
|
||||
|
||||
async def run_rag(host: str, port: int):
|
||||
api = AgentsClient(f"http://{host}:{port}")
|
||||
|
||||
urls = [
|
||||
"memory_optimizations.rst",
|
||||
"chat.rst",
|
||||
"llama3.rst",
|
||||
"datasets.rst",
|
||||
"qat_finetune.rst",
|
||||
"lora_finetune.rst",
|
||||
]
|
||||
attachments = [
|
||||
Attachment(
|
||||
content=URL(
|
||||
uri=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}"
|
||||
),
|
||||
mime_type="text/plain",
|
||||
)
|
||||
for i, url in enumerate(urls)
|
||||
]
|
||||
|
||||
# Alternatively, you can pre-populate the memory bank with documents for example,
|
||||
# using `llama_stack.memory.client`. Then you can grab the bank_id
|
||||
# from the output of that run.
|
||||
tool_definitions = [
|
||||
MemoryToolDefinition(
|
||||
max_tokens_in_context=2048,
|
||||
memory_bank_configs=[],
|
||||
),
|
||||
]
|
||||
|
||||
user_prompts = [
|
||||
"How do I use Lora?",
|
||||
"Tell me briefly about llama3 and torchtune",
|
||||
]
|
||||
|
||||
await _run_agent(api, tool_definitions, user_prompts, attachments)
|
||||
|
||||
|
||||
def main(host: str, port: int, rag: bool = False):
|
||||
fn = run_rag if rag else run_main
|
||||
asyncio.run(fn(host, port))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
Loading…
Add table
Add a link
Reference in a new issue