mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
# What does this PR do? - as title, cleaning up `import *`'s - upgrade tests to make them more robust to bad model outputs - remove import *'s in llama_stack/apis/* (skip __init__ modules) <img width="465" alt="image" src="https://github.com/user-attachments/assets/d8339c13-3b40-4ba5-9c53-0d2329726ee2" /> - run `sh run_openapi_generator.sh`, no types gets affected ## Test Plan ### Providers Tests **agents** ``` pytest -v -s llama_stack/providers/tests/agents/test_agents.py -m "together" --safety-shield meta-llama/Llama-Guard-3-8B --inference-model meta-llama/Llama-3.1-405B-Instruct-FP8 ``` **inference** ```bash # meta-reference torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="meta-llama/Llama-3.2-11B-Vision-Instruct" ./llama_stack/providers/tests/inference/test_vision_inference.py # together pytest -v -s -k "together" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py pytest -v -s -k "together" --inference-model="meta-llama/Llama-3.2-11B-Vision-Instruct" ./llama_stack/providers/tests/inference/test_vision_inference.py pytest ./llama_stack/providers/tests/inference/test_prompt_adapter.py ``` **safety** ``` pytest -v -s llama_stack/providers/tests/safety/test_safety.py -m together --safety-shield meta-llama/Llama-Guard-3-8B ``` **memory** ``` pytest -v -s llama_stack/providers/tests/memory/test_memory.py -m "sentence_transformers" --env EMBEDDING_DIMENSION=384 ``` **scoring** ``` pytest -v -s -m llm_as_judge_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py --judge-model meta-llama/Llama-3.2-3B-Instruct pytest -v -s -m basic_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py pytest -v -s -m braintrust_scoring_together_inference 
llama_stack/providers/tests/scoring/test_scoring.py ``` **datasetio** ``` pytest -v -s -m localfs llama_stack/providers/tests/datasetio/test_datasetio.py pytest -v -s -m huggingface llama_stack/providers/tests/datasetio/test_datasetio.py ``` **eval** ``` pytest -v -s -m meta_reference_eval_together_inference llama_stack/providers/tests/eval/test_eval.py pytest -v -s -m meta_reference_eval_together_inference_huggingface_datasetio llama_stack/providers/tests/eval/test_eval.py ``` ### Client-SDK Tests ``` LLAMA_STACK_BASE_URL=http://localhost:5000 pytest -v ./tests/client-sdk ``` ### llama-stack-apps ``` PORT=5000 LOCALHOST=localhost python -m examples.agents.hello $LOCALHOST $PORT python -m examples.agents.inflation $LOCALHOST $PORT python -m examples.agents.podcast_transcript $LOCALHOST $PORT python -m examples.agents.rag_as_attachments $LOCALHOST $PORT python -m examples.agents.rag_with_memory_bank $LOCALHOST $PORT python -m examples.safety.llama_guard_demo_mm $LOCALHOST $PORT python -m examples.agents.e2e_loop_with_custom_tools $LOCALHOST $PORT # Vision model python -m examples.interior_design_assistant.app python -m examples.agent_store.app $LOCALHOST $PORT ``` ### CLI ``` which llama llama model prompt-format -m Llama3.2-11B-Vision-Instruct llama model list llama stack list-apis llama stack list-providers inference llama stack build --template ollama --image-type conda ``` ### Distributions Tests **ollama** ``` llama stack build --template ollama --image-type conda ollama run llama3.2:1b-instruct-fp16 llama stack run ./llama_stack/templates/ollama/run.yaml --env INFERENCE_MODEL=meta-llama/Llama-3.2-1B-Instruct ``` **fireworks** ``` llama stack build --template fireworks --image-type conda llama stack run ./llama_stack/templates/fireworks/run.yaml ``` **together** ``` llama stack build --template together --image-type conda llama stack run ./llama_stack/templates/together/run.yaml ``` **tgi** ``` llama stack run ./llama_stack/templates/tgi/run.yaml --env 
TGI_URL=http://0.0.0.0:5009 --env INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct ``` ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
489 lines
13 KiB
Python
489 lines
13 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from datetime import datetime
|
|
from enum import Enum
|
|
from typing import (
|
|
Any,
|
|
AsyncIterator,
|
|
Dict,
|
|
List,
|
|
Literal,
|
|
Optional,
|
|
Protocol,
|
|
runtime_checkable,
|
|
Union,
|
|
)
|
|
|
|
from llama_models.llama3.api.datatypes import ToolParamDefinition
|
|
|
|
from llama_models.schema_utils import json_schema_type, webmethod
|
|
|
|
from pydantic import BaseModel, ConfigDict, Field
|
|
from typing_extensions import Annotated
|
|
|
|
from llama_stack.apis.common.content_types import InterleavedContent, URL
|
|
from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig
|
|
from llama_stack.apis.inference import (
|
|
CompletionMessage,
|
|
SamplingParams,
|
|
ToolCall,
|
|
ToolCallDelta,
|
|
ToolChoice,
|
|
ToolPromptFormat,
|
|
ToolResponse,
|
|
ToolResponseMessage,
|
|
UserMessage,
|
|
)
|
|
from llama_stack.apis.memory import MemoryBank
|
|
from llama_stack.apis.safety import SafetyViolation
|
|
|
|
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
|
|
|
|
|
# Content paired with the MIME type that describes it.
@json_schema_type
class Attachment(BaseModel):
    # Payload: either interleaved content or a URL.
    content: InterleavedContent | URL
    # MIME type string for `content`.
    mime_type: str
|
|
|
|
|
|
# Identifiers for the tool kinds an agent can be configured with.
# Each member's value is the string used as the `type` tag of the
# corresponding *ToolDefinition model in this module.
class AgentTool(Enum):
    brave_search = "brave_search"
    wolfram_alpha = "wolfram_alpha"
    photogen = "photogen"
    code_interpreter = "code_interpreter"

    function_call = "function_call"
    memory = "memory"
|
|
|
|
|
|
# Fields shared by every tool definition model in this module.
class ToolDefinitionCommon(BaseModel):
    # Shield identifiers; both lists default to a fresh empty list.
    # NOTE(review): presumably applied to the tool's input / output
    # respectively -- confirm against the agent implementation.
    input_shields: Optional[List[str]] = Field(default_factory=list)
    output_shields: Optional[List[str]] = Field(default_factory=list)
|
|
|
|
|
|
# Search backends selectable through `SearchToolDefinition.engine`.
class SearchEngineType(Enum):
    bing = "bing"
    brave = "brave"
    tavily = "tavily"
|
|
|
|
|
|
# Configuration for the web-search tool.
@json_schema_type
class SearchToolDefinition(ToolDefinitionCommon):
    # NOTE: brave_search is just a placeholder since model always uses
    # brave_search as tool call name
    type: Literal[AgentTool.brave_search.value] = AgentTool.brave_search.value
    # Required API key string (no default).
    api_key: str
    # Which search backend to use; defaults to brave.
    engine: SearchEngineType = SearchEngineType.brave
    # Optional REST execution config; None when not supplied.
    remote_execution: Optional[RestAPIExecutionConfig] = None
|
|
|
|
|
|
# Configuration for the Wolfram Alpha tool.
@json_schema_type
class WolframAlphaToolDefinition(ToolDefinitionCommon):
    # Discriminator tag; fixed to "wolfram_alpha".
    type: Literal[AgentTool.wolfram_alpha.value] = AgentTool.wolfram_alpha.value
    # Required API key string (no default).
    api_key: str
    # Optional REST execution config; None when not supplied.
    remote_execution: Optional[RestAPIExecutionConfig] = None
|
|
|
|
|
|
# Configuration for the photogen tool.
@json_schema_type
class PhotogenToolDefinition(ToolDefinitionCommon):
    # Discriminator tag; fixed to "photogen".
    type: Literal[AgentTool.photogen.value] = AgentTool.photogen.value
    # Optional REST execution config; None when not supplied.
    remote_execution: Optional[RestAPIExecutionConfig] = None
|
|
|
|
|
|
# Configuration for the code interpreter tool.
@json_schema_type
class CodeInterpreterToolDefinition(ToolDefinitionCommon):
    # Discriminator tag; fixed to "code_interpreter".
    type: Literal[AgentTool.code_interpreter.value] = AgentTool.code_interpreter.value
    # Defaults to True.
    enable_inline_code_execution: bool = True
    # Optional REST execution config; None when not supplied.
    remote_execution: Optional[RestAPIExecutionConfig] = None
|
|
|
|
|
|
# Configuration for a caller-defined function tool.
@json_schema_type
class FunctionCallToolDefinition(ToolDefinitionCommon):
    # Discriminator tag; fixed to "function_call".
    type: Literal[AgentTool.function_call.value] = AgentTool.function_call.value
    # Name and description of the function (both required).
    function_name: str
    description: str
    # Parameter definitions keyed by a string.
    # NOTE(review): presumably the key is the parameter name -- confirm.
    parameters: Dict[str, ToolParamDefinition]
    # Optional REST execution config; None when not supplied.
    remote_execution: Optional[RestAPIExecutionConfig] = None
|
|
|
|
|
|
# Private base for the per-kind memory bank configs below.
class _MemoryBankConfigCommon(BaseModel):
    # Identifier of the memory bank this config refers to (required).
    bank_id: str
|
|
|
|
|
|
# Memory bank config tagged as a "vector" bank.
class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon):
    # Discriminator tag; fixed to "vector".
    type: Literal["vector"] = "vector"
|
|
|
|
|
|
# Memory bank config tagged as a "keyvalue" bank.
class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
    # Discriminator tag; fixed to "keyvalue".
    type: Literal["keyvalue"] = "keyvalue"
    keys: List[str]  # what keys to focus on
|
|
|
|
|
|
# Memory bank config tagged as a "keyword" bank.
class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon):
    # Discriminator tag; fixed to "keyword".
    type: Literal["keyword"] = "keyword"
|
|
|
|
|
|
# Memory bank config tagged as a "graph" bank.
class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon):
    # Discriminator tag; fixed to "graph".
    type: Literal["graph"] = "graph"
    entities: List[str]  # what entities to focus on
|
|
|
|
|
|
# Tagged union over the four memory bank config models; the concrete
# model is selected by the value of each model's `type` field.
MemoryBankConfig = Annotated[
    Union[
        AgentVectorMemoryBankConfig,
        AgentKeyValueMemoryBankConfig,
        AgentKeywordMemoryBankConfig,
        AgentGraphMemoryBankConfig,
    ],
    Field(discriminator="type"),
]
|
|
|
|
|
|
# Kinds of memory query generator; each value is the `type` tag of the
# matching *MemoryQueryGeneratorConfig model in this module.
class MemoryQueryGenerator(Enum):
    default = "default"
    llm = "llm"
    custom = "custom"
|
|
|
|
|
|
# Config for the "default" memory query generator.
class DefaultMemoryQueryGeneratorConfig(BaseModel):
    # Discriminator tag; fixed to "default".
    type: Literal[MemoryQueryGenerator.default.value] = (
        MemoryQueryGenerator.default.value
    )
    # Separator string; defaults to a single space.
    # NOTE(review): presumably used to join message texts into one
    # query -- confirm against the generator implementation.
    sep: str = " "
|
|
|
|
|
|
# Config for the "llm" memory query generator.
class LLMMemoryQueryGeneratorConfig(BaseModel):
    # Discriminator tag; fixed to "llm".
    type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
    # Model identifier and prompt template (both required strings).
    model: str
    template: str
|
|
|
|
|
|
# Config for the "custom" memory query generator; carries only the tag.
class CustomMemoryQueryGeneratorConfig(BaseModel):
    # Discriminator tag; fixed to "custom".
    type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
|
|
|
|
|
|
# Tagged union over the three query generator configs; the concrete
# model is selected by the value of each model's `type` field.
MemoryQueryGeneratorConfig = Annotated[
    Union[
        DefaultMemoryQueryGeneratorConfig,
        LLMMemoryQueryGeneratorConfig,
        CustomMemoryQueryGeneratorConfig,
    ],
    Field(discriminator="type"),
]
|
|
|
|
|
|
# Configuration for the memory (retrieval) tool.
@json_schema_type
class MemoryToolDefinition(ToolDefinitionCommon):
    # Discriminator tag; fixed to "memory".
    type: Literal[AgentTool.memory.value] = AgentTool.memory.value
    # Memory banks to consult; defaults to a fresh empty list.
    memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
    # This config defines how a query is generated using the messages
    # for memory bank retrieval.
    query_generator_config: MemoryQueryGeneratorConfig = Field(
        default=DefaultMemoryQueryGeneratorConfig()
    )
    # Numeric limits with defaults of 4096 and 10.
    # NOTE(review): presumably caps on retrieved context size and on
    # the number of retrieved chunks -- confirm where they are consumed.
    max_tokens_in_context: int = 4096
    max_chunks: int = 10
|
|
|
|
|
|
# Tagged union over every tool definition model; the concrete model is
# selected by the value of each model's `type` field.
AgentToolDefinition = Annotated[
    Union[
        SearchToolDefinition,
        WolframAlphaToolDefinition,
        PhotogenToolDefinition,
        CodeInterpreterToolDefinition,
        FunctionCallToolDefinition,
        MemoryToolDefinition,
    ],
    Field(discriminator="type"),
]
|
|
|
|
|
|
# Fields shared by every step model in this module.
class StepCommon(BaseModel):
    # Identifier of the turn the step belongs to, and of the step itself.
    turn_id: str
    step_id: str
    # Optional timestamps; both default to None.
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
|
|
|
|
|
|
# Kinds of step; each value is the `step_type` tag of the matching
# *Step model in this module.
class StepType(Enum):
    inference = "inference"
    tool_execution = "tool_execution"
    shield_call = "shield_call"
    memory_retrieval = "memory_retrieval"
|
|
|
|
|
|
# Step record carrying a model's completion message.
@json_schema_type
class InferenceStep(StepCommon):
    # Clear pydantic's protected namespaces so the `model_response`
    # field name (which starts with "model_") is accepted without warning.
    model_config = ConfigDict(protected_namespaces=())

    # Discriminator tag; fixed to "inference".
    step_type: Literal[StepType.inference.value] = StepType.inference.value
    model_response: CompletionMessage
|
|
|
|
|
|
# Step record carrying tool calls and their responses.
@json_schema_type
class ToolExecutionStep(StepCommon):
    # Discriminator tag; fixed to "tool_execution".
    step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
    tool_calls: List[ToolCall]
    tool_responses: List[ToolResponse]
|
|
|
|
|
|
# Step record for a shield call.
@json_schema_type
class ShieldCallStep(StepCommon):
    # Discriminator tag; fixed to "shield_call".
    step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
    # May be None, but has no default, so it must be passed explicitly.
    violation: Optional[SafetyViolation]
|
|
|
|
|
|
# Step record for a memory retrieval.
@json_schema_type
class MemoryRetrievalStep(StepCommon):
    # Discriminator tag; fixed to "memory_retrieval".
    step_type: Literal[StepType.memory_retrieval.value] = (
        StepType.memory_retrieval.value
    )
    # Identifiers of the memory banks involved in this step.
    memory_bank_ids: List[str]
    # Content recorded as inserted context for this step.
    inserted_context: InterleavedContent
|
|
|
|
|
|
# Tagged union over the four step models; the concrete model is
# selected by the value of each model's `step_type` field.
Step = Annotated[
    Union[
        InferenceStep,
        ToolExecutionStep,
        ShieldCallStep,
        MemoryRetrievalStep,
    ],
    Field(discriminator="step_type"),
]
|
|
|
|
|
|
@json_schema_type
class Turn(BaseModel):
    """A single turn in an interaction with an Agentic System."""

    # Identifiers of this turn and of the session it belongs to.
    turn_id: str
    session_id: str
    # Messages that started the turn: user messages and/or tool responses.
    input_messages: List[
        Union[
            UserMessage,
            ToolResponseMessage,
        ]
    ]
    # Step records for the turn.
    steps: List[Step]
    output_message: CompletionMessage
    # Defaults to a fresh empty list.
    output_attachments: List[Attachment] = Field(default_factory=list)

    # `started_at` is required; `completed_at` defaults to None.
    started_at: datetime
    completed_at: Optional[datetime] = None
|
|
|
|
|
|
@json_schema_type
class Session(BaseModel):
    """A single session of an interaction with an Agentic System."""

    session_id: str
    session_name: str
    # Turns that belong to this session.
    turns: List[Turn]
    started_at: datetime

    # Optional memory bank associated with the session; defaults to None.
    memory_bank: Optional[MemoryBank] = None
|
|
|
|
|
|
# Settings shared by `AgentConfig` and the per-turn override model.
class AgentConfigCommon(BaseModel):
    # Defaults to a `SamplingParams` built with no arguments.
    sampling_params: Optional[SamplingParams] = SamplingParams()

    # Shield identifiers; both lists default to a fresh empty list.
    input_shields: Optional[List[str]] = Field(default_factory=list)
    output_shields: Optional[List[str]] = Field(default_factory=list)

    # Tool definitions (empty by default), with tool choice defaulting
    # to `auto` and tool prompt format defaulting to `json`.
    tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
    tool_prompt_format: Optional[ToolPromptFormat] = Field(
        default=ToolPromptFormat.json
    )

    # Defaults to 10.
    # NOTE(review): presumably the cap on inference iterations within a
    # single turn -- confirm against the agent loop.
    max_infer_iters: int = 10
|
|
|
|
|
|
# Full agent configuration: the common settings plus three required fields.
@json_schema_type
class AgentConfig(AgentConfigCommon):
    # Model identifier string.
    model: str
    # Instruction text for the agent.
    instructions: str
    # Required flag (no default).
    enable_session_persistence: bool
|
|
|
|
|
|
# The common agent settings plus an optional `instructions` field.
class AgentConfigOverridablePerTurn(AgentConfigCommon):
    # Defaults to None.
    instructions: Optional[str] = None
|
|
|
|
|
|
# Kinds of streamed turn event; each value is the `event_type` tag of
# the matching AgentTurnResponse*Payload model in this module.
class AgentTurnResponseEventType(Enum):
    step_start = "step_start"
    step_complete = "step_complete"
    step_progress = "step_progress"

    turn_start = "turn_start"
    turn_complete = "turn_complete"
|
|
|
|
|
|
# Event payload tagged "step_start".
@json_schema_type
class AgentTurnResponseStepStartPayload(BaseModel):
    # Discriminator tag; fixed to "step_start".
    event_type: Literal[AgentTurnResponseEventType.step_start.value] = (
        AgentTurnResponseEventType.step_start.value
    )
    # Kind and identifier of the step this event refers to.
    step_type: StepType
    step_id: str
    # Free-form metadata; defaults to a fresh empty dict.
    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
|
|
|
|
|
# Event payload tagged "step_complete".
@json_schema_type
class AgentTurnResponseStepCompletePayload(BaseModel):
    # Discriminator tag; fixed to "step_complete".
    event_type: Literal[AgentTurnResponseEventType.step_complete.value] = (
        AgentTurnResponseEventType.step_complete.value
    )
    step_type: StepType
    # The step record itself (one of the `Step` union members).
    step_details: Step
|
|
|
|
|
|
# Event payload tagged "step_progress".
@json_schema_type
class AgentTurnResponseStepProgressPayload(BaseModel):
    # Clear pydantic's protected namespaces (prefix "model_" by default).
    model_config = ConfigDict(protected_namespaces=())

    # Discriminator tag; fixed to "step_progress".
    event_type: Literal[AgentTurnResponseEventType.step_progress.value] = (
        AgentTurnResponseEventType.step_progress.value
    )
    # Kind and identifier of the step this event refers to.
    step_type: StepType
    step_id: str

    # Incremental content; both fields default to None.
    text_delta: Optional[str] = None
    tool_call_delta: Optional[ToolCallDelta] = None
|
|
|
|
|
|
# Event payload tagged "turn_start".
@json_schema_type
class AgentTurnResponseTurnStartPayload(BaseModel):
    # Discriminator tag; fixed to "turn_start".
    event_type: Literal[AgentTurnResponseEventType.turn_start.value] = (
        AgentTurnResponseEventType.turn_start.value
    )
    # Identifier of the turn this event refers to.
    turn_id: str
|
|
|
|
|
|
# Event payload tagged "turn_complete".
@json_schema_type
class AgentTurnResponseTurnCompletePayload(BaseModel):
    # Discriminator tag; fixed to "turn_complete".
    event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = (
        AgentTurnResponseEventType.turn_complete.value
    )
    # The `Turn` record carried by this event.
    turn: Turn
|
|
|
|
|
|
@json_schema_type
class AgentTurnResponseEvent(BaseModel):
    """Streamed agent execution response."""

    # One of the five payload models, selected by its `event_type` tag.
    payload: Annotated[
        Union[
            AgentTurnResponseStepStartPayload,
            AgentTurnResponseStepProgressPayload,
            AgentTurnResponseStepCompletePayload,
            AgentTurnResponseTurnStartPayload,
            AgentTurnResponseTurnCompletePayload,
        ],
        Field(discriminator="event_type"),
    ]
|
|
|
|
|
|
# Return type of `Agents.create_agent`.
@json_schema_type
class AgentCreateResponse(BaseModel):
    # Identifier of the agent.
    agent_id: str
|
|
|
|
|
|
# Return type of `Agents.create_agent_session`.
@json_schema_type
class AgentSessionCreateResponse(BaseModel):
    # Identifier of the session.
    session_id: str
|
|
|
|
|
|
# Request model for creating a turn: the per-turn overridable agent
# settings plus the target agent/session and the turn's input.
@json_schema_type
class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
    agent_id: str
    session_id: str

    # TODO: figure out how we can simplify this and make it clear why
    # ToolResponseMessage needs to be here (it is function call
    # execution from outside the system)
    messages: List[
        Union[
            UserMessage,
            ToolResponseMessage,
        ]
    ]
    # Optional attachments; defaults to None.
    attachments: Optional[List[Attachment]] = None

    # Defaults to False.
    stream: Optional[bool] = False
|
|
|
|
|
|
@json_schema_type
class AgentTurnResponseStreamChunk(BaseModel):
    """streamed agent turn completion response."""

    # The event carried by this chunk.
    event: AgentTurnResponseEvent
|
|
|
|
|
|
# Return type of `Agents.get_agents_step`.
@json_schema_type
class AgentStepResponse(BaseModel):
    # The step record (one of the `Step` union members).
    step: Step
|
|
|
|
|
|
# Structural interface for the agents API. Every method is an async
# stub (`...`) registered under a route via `@webmethod`; implementations
# are supplied elsewhere.
@runtime_checkable
@trace_protocol
class Agents(Protocol):
    # Route /agents/create: takes an `AgentConfig`, returns the
    # response model carrying the agent id.
    @webmethod(route="/agents/create")
    async def create_agent(
        self,
        agent_config: AgentConfig,
    ) -> AgentCreateResponse: ...

    # Route /agents/turn/create: the declared return type is either a
    # `Turn` or an async iterator of stream chunks.
    # NOTE(review): presumably `stream` selects which of the two is
    # returned -- confirm against an implementation.
    @webmethod(route="/agents/turn/create")
    async def create_agent_turn(
        self,
        agent_id: str,
        session_id: str,
        messages: List[
            Union[
                UserMessage,
                ToolResponseMessage,
            ]
        ],
        attachments: Optional[List[Attachment]] = None,
        stream: Optional[bool] = False,
    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...

    # Route /agents/turn/get: addressed by agent, session and turn id.
    @webmethod(route="/agents/turn/get")
    async def get_agents_turn(
        self, agent_id: str, session_id: str, turn_id: str
    ) -> Turn: ...

    # Route /agents/step/get: addressed by agent, session, turn and
    # step id.
    @webmethod(route="/agents/step/get")
    async def get_agents_step(
        self, agent_id: str, session_id: str, turn_id: str, step_id: str
    ) -> AgentStepResponse: ...

    # Route /agents/session/create: takes an agent id and a session
    # name, returns the response model carrying the session id.
    @webmethod(route="/agents/session/create")
    async def create_agent_session(
        self,
        agent_id: str,
        session_name: str,
    ) -> AgentSessionCreateResponse: ...

    # Route /agents/session/get: `turn_ids` is optional (default None).
    # NOTE(review): presumably it restricts which turns are included --
    # confirm against an implementation.
    @webmethod(route="/agents/session/get")
    async def get_agents_session(
        self,
        agent_id: str,
        session_id: str,
        turn_ids: Optional[List[str]] = None,
    ) -> Session: ...

    # Route /agents/session/delete: returns nothing.
    @webmethod(route="/agents/session/delete")
    async def delete_agents_session(self, agent_id: str, session_id: str) -> None: ...

    # Route /agents/delete: returns nothing.
    @webmethod(route="/agents/delete")
    async def delete_agents(
        self,
        agent_id: str,
    ) -> None: ...
|