Codemod from llama_toolchain -> llama_stack

- added providers/registry
- cleaned up api/ subdirectories and moved impls away
- restructured api/api.py
- from llama_stack.apis.<api> import foo should work now
- update imports to do llama_stack.apis.<api>
- update many other imports
- added __init__, fixed some registry imports
- updated registry imports
- create_agentic_system -> create_agent
- AgenticSystem -> Agent
Ashwin Bharambe 2024-09-16 17:34:07 -07:00
parent 2cf731faea
commit 76b354a081
128 changed files with 381 additions and 376 deletions
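The bulk of the diff below is mechanical: import paths move from llama_stack.<api>.api to llama_stack.apis.<api>, and the AgenticSystem naming becomes Agent/Agents throughout (types, methods, and REST routes). A before/after sketch of a typical call site, using only names taken from the diffs that follow:

```python
# Before this commit (llama_toolchain-era layout):
#   from llama_stack.inference.api import Inference
#   from llama_stack.agentic_system.api import AgenticSystem

# After this commit:
from llama_stack.apis.inference import Inference
from llama_stack.apis.agents import Agents  # AgenticSystem -> Agents

# Method and route renames follow the same pattern:
#   create_agentic_system(...)  -> create_agent(...)
#   /agentic_system/create      -> /agents/create
```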

@@ -482,7 +482,7 @@ Once the server is setup, we can test it with a client to see the example output
 cd /path/to/llama-stack
 conda activate <env> # any environment containing the llama-toolchain pip package will work
-python -m llama_stack.inference.client localhost 5000
+python -m llama_stack.apis.inference.client localhost 5000
 ```
 This will run the chat completion client and query the distributions /inference/chat_completion API.

@@ -296,7 +296,7 @@ Once the server is setup, we can test it with a client to see the example output
 cd /path/to/llama-stack
 conda activate <env> # any environment containing the llama-toolchain pip package will work
-python -m llama_stack.inference.client localhost 5000
+python -m llama_stack.apis.inference.client localhost 5000
 ```
 This will run the chat completion client and query the distributions /inference/chat_completion API.
@@ -314,7 +314,7 @@ You know what's even more hilarious? People like you who think they can just Goo
 Similarly you can test safety (if you configured llama-guard and/or prompt-guard shields) by:
 ```
-python -m llama_stack.safety.client localhost 5000
+python -m llama_stack.apis.safety.client localhost 5000
 ```
 You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/sdk_examples) repo.

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .agents import *  # noqa: F401 F403

@@ -15,9 +15,9 @@ from typing_extensions import Annotated

 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.common.deployment_types import *  # noqa: F403
-from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.safety.api import *  # noqa: F403
-from llama_stack.memory.api import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.safety import *  # noqa: F403
+from llama_stack.apis.memory import *  # noqa: F403

 @json_schema_type
@@ -26,7 +26,7 @@ class Attachment(BaseModel):
     mime_type: str

-class AgenticSystemTool(Enum):
+class AgentTool(Enum):
     brave_search = "brave_search"
     wolfram_alpha = "wolfram_alpha"
     photogen = "photogen"
@@ -50,41 +50,33 @@ class SearchEngineType(Enum):
 class SearchToolDefinition(ToolDefinitionCommon):
     # NOTE: brave_search is just a placeholder since model always uses
     # brave_search as tool call name
-    type: Literal[AgenticSystemTool.brave_search.value] = (
-        AgenticSystemTool.brave_search.value
-    )
+    type: Literal[AgentTool.brave_search.value] = AgentTool.brave_search.value
     engine: SearchEngineType = SearchEngineType.brave
     remote_execution: Optional[RestAPIExecutionConfig] = None

 @json_schema_type
 class WolframAlphaToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgenticSystemTool.wolfram_alpha.value] = (
-        AgenticSystemTool.wolfram_alpha.value
-    )
+    type: Literal[AgentTool.wolfram_alpha.value] = AgentTool.wolfram_alpha.value
     remote_execution: Optional[RestAPIExecutionConfig] = None

 @json_schema_type
 class PhotogenToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgenticSystemTool.photogen.value] = AgenticSystemTool.photogen.value
+    type: Literal[AgentTool.photogen.value] = AgentTool.photogen.value
     remote_execution: Optional[RestAPIExecutionConfig] = None

 @json_schema_type
 class CodeInterpreterToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgenticSystemTool.code_interpreter.value] = (
-        AgenticSystemTool.code_interpreter.value
-    )
+    type: Literal[AgentTool.code_interpreter.value] = AgentTool.code_interpreter.value
     enable_inline_code_execution: bool = True
     remote_execution: Optional[RestAPIExecutionConfig] = None

 @json_schema_type
 class FunctionCallToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgenticSystemTool.function_call.value] = (
-        AgenticSystemTool.function_call.value
-    )
+    type: Literal[AgentTool.function_call.value] = AgentTool.function_call.value
     function_name: str
     description: str
     parameters: Dict[str, ToolParamDefinition]
@@ -95,30 +87,30 @@ class _MemoryBankConfigCommon(BaseModel):
     bank_id: str

-class AgenticSystemVectorMemoryBankConfig(_MemoryBankConfigCommon):
+class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon):
     type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value

-class AgenticSystemKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
+class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
     type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
     keys: List[str]  # what keys to focus on

-class AgenticSystemKeywordMemoryBankConfig(_MemoryBankConfigCommon):
+class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon):
     type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value

-class AgenticSystemGraphMemoryBankConfig(_MemoryBankConfigCommon):
+class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon):
     type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
     entities: List[str]  # what entities to focus on

 MemoryBankConfig = Annotated[
     Union[
-        AgenticSystemVectorMemoryBankConfig,
-        AgenticSystemKeyValueMemoryBankConfig,
-        AgenticSystemKeywordMemoryBankConfig,
-        AgenticSystemGraphMemoryBankConfig,
+        AgentVectorMemoryBankConfig,
+        AgentKeyValueMemoryBankConfig,
+        AgentKeywordMemoryBankConfig,
+        AgentGraphMemoryBankConfig,
     ],
     Field(discriminator="type"),
 ]
@@ -158,7 +150,7 @@ MemoryQueryGeneratorConfig = Annotated[
 class MemoryToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
+    type: Literal[AgentTool.memory.value] = AgentTool.memory.value
     memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
     # This config defines how a query is generated using the messages
     # for memory bank retrieval.
@@ -169,7 +161,7 @@ class MemoryToolDefinition(ToolDefinitionCommon):
     max_chunks: int = 10

-AgenticSystemToolDefinition = Annotated[
+AgentToolDefinition = Annotated[
     Union[
         SearchToolDefinition,
         WolframAlphaToolDefinition,
@@ -275,7 +267,7 @@ class AgentConfigCommon(BaseModel):
     input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
     output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
-    tools: Optional[List[AgenticSystemToolDefinition]] = Field(default_factory=list)
+    tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
     tool_prompt_format: Optional[ToolPromptFormat] = Field(
         default=ToolPromptFormat.json
@@ -292,7 +284,7 @@ class AgentConfigOverridablePerTurn(AgentConfigCommon):
     instructions: Optional[str] = None

-class AgenticSystemTurnResponseEventType(Enum):
+class AgentTurnResponseEventType(Enum):
     step_start = "step_start"
     step_complete = "step_complete"
     step_progress = "step_progress"
@@ -302,9 +294,9 @@ class AgenticSystemTurnResponseEventType(Enum):
 @json_schema_type
-class AgenticSystemTurnResponseStepStartPayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
-        AgenticSystemTurnResponseEventType.step_start.value
+class AgentTurnResponseStepStartPayload(BaseModel):
+    event_type: Literal[AgentTurnResponseEventType.step_start.value] = (
+        AgentTurnResponseEventType.step_start.value
     )
     step_type: StepType
     step_id: str
@@ -312,20 +304,20 @@ class AgenticSystemTurnResponseStepStartPayload(BaseModel):
 @json_schema_type
-class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
-        AgenticSystemTurnResponseEventType.step_complete.value
+class AgentTurnResponseStepCompletePayload(BaseModel):
+    event_type: Literal[AgentTurnResponseEventType.step_complete.value] = (
+        AgentTurnResponseEventType.step_complete.value
     )
     step_type: StepType
     step_details: Step

 @json_schema_type
-class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
+class AgentTurnResponseStepProgressPayload(BaseModel):
     model_config = ConfigDict(protected_namespaces=())
-    event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
-        AgenticSystemTurnResponseEventType.step_progress.value
+    event_type: Literal[AgentTurnResponseEventType.step_progress.value] = (
+        AgentTurnResponseEventType.step_progress.value
     )
     step_type: StepType
     step_id: str
@@ -336,49 +328,49 @@ class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
 @json_schema_type
-class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
-        AgenticSystemTurnResponseEventType.turn_start.value
+class AgentTurnResponseTurnStartPayload(BaseModel):
+    event_type: Literal[AgentTurnResponseEventType.turn_start.value] = (
+        AgentTurnResponseEventType.turn_start.value
     )
     turn_id: str

 @json_schema_type
-class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
-        AgenticSystemTurnResponseEventType.turn_complete.value
+class AgentTurnResponseTurnCompletePayload(BaseModel):
+    event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = (
+        AgentTurnResponseEventType.turn_complete.value
     )
     turn: Turn

 @json_schema_type
-class AgenticSystemTurnResponseEvent(BaseModel):
+class AgentTurnResponseEvent(BaseModel):
     """Streamed agent execution response."""
     payload: Annotated[
         Union[
-            AgenticSystemTurnResponseStepStartPayload,
-            AgenticSystemTurnResponseStepProgressPayload,
-            AgenticSystemTurnResponseStepCompletePayload,
-            AgenticSystemTurnResponseTurnStartPayload,
-            AgenticSystemTurnResponseTurnCompletePayload,
+            AgentTurnResponseStepStartPayload,
+            AgentTurnResponseStepProgressPayload,
+            AgentTurnResponseStepCompletePayload,
+            AgentTurnResponseTurnStartPayload,
+            AgentTurnResponseTurnCompletePayload,
         ],
         Field(discriminator="event_type"),
     ]

 @json_schema_type
-class AgenticSystemCreateResponse(BaseModel):
+class AgentCreateResponse(BaseModel):
     agent_id: str

 @json_schema_type
-class AgenticSystemSessionCreateResponse(BaseModel):
+class AgentSessionCreateResponse(BaseModel):
     session_id: str

 @json_schema_type
-class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn):
+class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
     agent_id: str
     session_id: str
@@ -397,24 +389,24 @@ class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn):
 @json_schema_type
-class AgenticSystemTurnResponseStreamChunk(BaseModel):
-    event: AgenticSystemTurnResponseEvent
+class AgentTurnResponseStreamChunk(BaseModel):
+    event: AgentTurnResponseEvent

 @json_schema_type
-class AgenticSystemStepResponse(BaseModel):
+class AgentStepResponse(BaseModel):
     step: Step

-class AgenticSystem(Protocol):
-    @webmethod(route="/agentic_system/create")
-    async def create_agentic_system(
+class Agents(Protocol):
+    @webmethod(route="/agents/create")
+    async def create_agent(
         self,
         agent_config: AgentConfig,
-    ) -> AgenticSystemCreateResponse: ...
+    ) -> AgentCreateResponse: ...

-    @webmethod(route="/agentic_system/turn/create")
-    async def create_agentic_system_turn(
+    @webmethod(route="/agents/turn/create")
+    async def create_agent_turn(
         self,
         agent_id: str,
         session_id: str,
@@ -426,42 +418,40 @@ class AgenticSystem(Protocol):
         ],
         attachments: Optional[List[Attachment]] = None,
         stream: Optional[bool] = False,
-    ) -> AgenticSystemTurnResponseStreamChunk: ...
+    ) -> AgentTurnResponseStreamChunk: ...

-    @webmethod(route="/agentic_system/turn/get")
-    async def get_agentic_system_turn(
+    @webmethod(route="/agents/turn/get")
+    async def get_agents_turn(
         self,
         agent_id: str,
         turn_id: str,
     ) -> Turn: ...

-    @webmethod(route="/agentic_system/step/get")
-    async def get_agentic_system_step(
+    @webmethod(route="/agents/step/get")
+    async def get_agents_step(
         self, agent_id: str, turn_id: str, step_id: str
-    ) -> AgenticSystemStepResponse: ...
+    ) -> AgentStepResponse: ...

-    @webmethod(route="/agentic_system/session/create")
-    async def create_agentic_system_session(
+    @webmethod(route="/agents/session/create")
+    async def create_agent_session(
         self,
         agent_id: str,
         session_name: str,
-    ) -> AgenticSystemSessionCreateResponse: ...
+    ) -> AgentSessionCreateResponse: ...

-    @webmethod(route="/agentic_system/session/get")
-    async def get_agentic_system_session(
+    @webmethod(route="/agents/session/get")
+    async def get_agents_session(
         self,
         agent_id: str,
         session_id: str,
         turn_ids: Optional[List[str]] = None,
     ) -> Session: ...

-    @webmethod(route="/agentic_system/session/delete")
-    async def delete_agentic_system_session(
-        self, agent_id: str, session_id: str
-    ) -> None: ...
+    @webmethod(route="/agents/session/delete")
+    async def delete_agents_session(self, agent_id: str, session_id: str) -> None: ...

-    @webmethod(route="/agentic_system/delete")
-    async def delete_agentic_system(
+    @webmethod(route="/agents/delete")
+    async def delete_agents(
         self,
         agent_id: str,
     ) -> None: ...
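All of the renamed tool and payload types above share one Pydantic pattern: each variant pins its type (or event_type) field to a Literal value, and the union is annotated with Field(discriminator=...), so a raw dict is routed to the matching class during validation. A minimal, self-contained sketch of that pattern with toy classes (not the real definitions above; it assumes the pydantic v2 TypeAdapter API):

```python
from typing import Literal, Union

from pydantic import BaseModel, Field, TypeAdapter
from typing_extensions import Annotated

class SearchTool(BaseModel):
    type: Literal["brave_search"] = "brave_search"
    engine: str = "brave"

class PhotogenTool(BaseModel):
    type: Literal["photogen"] = "photogen"

# The discriminator tells pydantic which field selects the variant.
ToolDef = Annotated[Union[SearchTool, PhotogenTool], Field(discriminator="type")]

# A dict with type="photogen" validates straight into PhotogenTool.
tool = TypeAdapter(ToolDef).validate_python({"type": "photogen"})
assert isinstance(tool, PhotogenTool)
```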

@@ -18,44 +18,42 @@ from termcolor import cprint

 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.core.datatypes import RemoteProviderConfig

-from .api import *  # noqa: F403
+from .agents import *  # noqa: F403
 from .event_logger import EventLogger

 async def get_client_impl(config: RemoteProviderConfig, _deps):
-    return AgenticSystemClient(config.url)
+    return AgentsClient(config.url)

 def encodable_dict(d: BaseModel):
     return json.loads(d.json())

-class AgenticSystemClient(AgenticSystem):
+class AgentsClient(Agents):
     def __init__(self, base_url: str):
         self.base_url = base_url

-    async def create_agentic_system(
-        self, agent_config: AgentConfig
-    ) -> AgenticSystemCreateResponse:
+    async def create_agent(self, agent_config: AgentConfig) -> AgentCreateResponse:
         async with httpx.AsyncClient() as client:
             response = await client.post(
-                f"{self.base_url}/agentic_system/create",
+                f"{self.base_url}/agents/create",
                 json={
                     "agent_config": encodable_dict(agent_config),
                 },
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
-            return AgenticSystemCreateResponse(**response.json())
+            return AgentCreateResponse(**response.json())

-    async def create_agentic_system_session(
+    async def create_agent_session(
         self,
         agent_id: str,
         session_name: str,
-    ) -> AgenticSystemSessionCreateResponse:
+    ) -> AgentSessionCreateResponse:
         async with httpx.AsyncClient() as client:
             response = await client.post(
-                f"{self.base_url}/agentic_system/session/create",
+                f"{self.base_url}/agents/session/create",
                 json={
                     "agent_id": agent_id,
                     "session_name": session_name,
@@ -63,16 +61,16 @@ class AgenticSystemClient(AgenticSystem):
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
-            return AgenticSystemSessionCreateResponse(**response.json())
+            return AgentSessionCreateResponse(**response.json())

-    async def create_agentic_system_turn(
+    async def create_agent_turn(
         self,
-        request: AgenticSystemTurnCreateRequest,
+        request: AgentTurnCreateRequest,
     ) -> AsyncGenerator:
         async with httpx.AsyncClient() as client:
             async with client.stream(
                 "POST",
-                f"{self.base_url}/agentic_system/turn/create",
+                f"{self.base_url}/agents/turn/create",
                 json=encodable_dict(request),
                 headers={"Content-Type": "application/json"},
                 timeout=20,
@@ -86,7 +84,7 @@ class AgenticSystemClient(AgenticSystem):
                             cprint(data, "red")
                             continue
-                        yield AgenticSystemTurnResponseStreamChunk(**jdata)
+                        yield AgentTurnResponseStreamChunk(**jdata)
                     except Exception as e:
                         print(data)
                         print(f"Error with parsing or validation: {e}")
@@ -102,16 +100,16 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
         tool_prompt_format=ToolPromptFormat.function_tag,
     )

-    create_response = await api.create_agentic_system(agent_config)
-    session_response = await api.create_agentic_system_session(
+    create_response = await api.create_agent(agent_config)
+    session_response = await api.create_agent_session(
         agent_id=create_response.agent_id,
         session_name="test_session",
     )

     for content in user_prompts:
         cprint(f"User> {content}", color="white", attrs=["bold"])
-        iterator = api.create_agentic_system_turn(
-            AgenticSystemTurnCreateRequest(
+        iterator = api.create_agent_turn(
+            AgentTurnCreateRequest(
                 agent_id=create_response.agent_id,
                 session_id=session_response.session_id,
                 messages=[
@@ -128,7 +126,7 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):

 async def run_main(host: str, port: int):
-    api = AgenticSystemClient(f"http://{host}:{port}")
+    api = AgentsClient(f"http://{host}:{port}")

     tool_definitions = [
         SearchToolDefinition(engine=SearchEngineType.bing),
@@ -165,7 +163,7 @@ async def run_main(host: str, port: int):

 async def run_rag(host: str, port: int):
-    api = AgenticSystemClient(f"http://{host}:{port}")
+    api = AgentsClient(f"http://{host}:{port}")

     urls = [
         "memory_optimizations.rst",

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .batch_inference import *  # noqa: F401 F403

@@ -11,7 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field

 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.inference.api import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403

 @json_schema_type

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .dataset import *  # noqa: F401 F403

@@ -4,4 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .api import *  # noqa: F401 F403
+from .evals import *  # noqa: F401 F403

@@ -12,7 +12,7 @@ from llama_models.schema_utils import webmethod
 from pydantic import BaseModel

 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.dataset.api import *  # noqa: F403
+from llama_stack.apis.dataset import *  # noqa: F403
 from llama_stack.common.training_types import *  # noqa: F403

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .inference import *  # noqa: F401 F403

@@ -10,12 +10,14 @@ from typing import Any, AsyncGenerator

 import fire
 import httpx
+from llama_stack.core.datatypes import RemoteProviderConfig
 from pydantic import BaseModel
 from termcolor import cprint

-from llama_stack.core.datatypes import RemoteProviderConfig
-from .api import (
+from .event_logger import EventLogger
+from .inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseStreamChunk,
@@ -23,7 +25,6 @@ from .api import (
     Inference,
     UserMessage,
 )
-from .event_logger import EventLogger

 async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .memory import *  # noqa: F401 F403

@@ -13,11 +13,11 @@ from typing import Any, Dict, List, Optional

 import fire
 import httpx
-from termcolor import cprint

 from llama_stack.core.datatypes import RemoteProviderConfig
+from termcolor import cprint

-from .api import *  # noqa: F403
+from .memory import *  # noqa: F403
 from .common.file_utils import data_url_from_file

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .models import *  # noqa: F401 F403

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .post_training import *  # noqa: F401 F403

@@ -14,7 +14,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field

 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.dataset.api import *  # noqa: F403
+from llama_stack.apis.dataset import *  # noqa: F403
 from llama_stack.common.training_types import *  # noqa: F403

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .reward_scoring import *  # noqa: F401 F403

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .safety import *  # noqa: F401 F403

@@ -13,12 +13,12 @@ import fire
 import httpx

 from llama_models.llama3.api.datatypes import UserMessage
+from llama_stack.core.datatypes import RemoteProviderConfig
 from pydantic import BaseModel
 from termcolor import cprint

-from llama_stack.core.datatypes import RemoteProviderConfig
-from .api import *  # noqa: F403
+from .safety import *  # noqa: F403

 async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety:

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .synthetic_data_generation import *  # noqa: F401 F403

@@ -13,7 +13,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel

 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.reward_scoring.api import *  # noqa: F403
+from llama_stack.apis.reward_scoring import *  # noqa: F403

 class FilteringFunction(Enum):

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .telemetry import *  # noqa: F401 F403

@@ -31,16 +31,6 @@ class LlamaCLIParser:
         ModelParser.create(subparsers)
         StackParser.create(subparsers)

-        # Import sub-commands from agentic_system if they exist
-        try:
-            from llama_agentic_system.cli.subcommand_modules import SUBCOMMAND_MODULES
-
-            for module in SUBCOMMAND_MODULES:
-                module.create(subparsers)
-        except ImportError:
-            pass
-
     def parse_args(self) -> argparse.Namespace:
         return self.parser.parse_args()

@@ -5,6 +5,6 @@ distribution_spec:
   inference: meta-reference
   memory: meta-reference-faiss
   safety: meta-reference
-  agentic_system: meta-reference
+  agents: meta-reference
   telemetry: console
 image_type: conda

@@ -5,6 +5,6 @@ distribution_spec:
   inference: remote::fireworks
   memory: meta-reference-faiss
   safety: meta-reference
-  agentic_system: meta-reference
+  agents: meta-reference
   telemetry: console
 image_type: conda

@@ -5,6 +5,6 @@ distribution_spec:
   inference: remote::ollama
   memory: meta-reference-faiss
   safety: meta-reference
-  agentic_system: meta-reference
+  agents: meta-reference
   telemetry: console
 image_type: conda

@@ -5,6 +5,6 @@ distribution_spec:
   inference: remote::tgi
   memory: meta-reference-faiss
   safety: meta-reference
-  agentic_system: meta-reference
+  agents: meta-reference
   telemetry: console
 image_type: conda

@@ -5,6 +5,6 @@ distribution_spec:
   inference: remote::together
   memory: meta-reference-faiss
   safety: meta-reference
-  agentic_system: meta-reference
+  agents: meta-reference
   telemetry: console
 image_type: conda

@@ -5,6 +5,6 @@ distribution_spec:
   inference: meta-reference
   memory: meta-reference-faiss
   safety: meta-reference
-  agentic_system: meta-reference
+  agents: meta-reference
   telemetry: console
 image_type: docker

@@ -17,7 +17,7 @@ from pydantic import BaseModel, Field, validator
 class Api(Enum):
     inference = "inference"
     safety = "safety"
-    agentic_system = "agentic_system"
+    agents = "agents"
     memory = "memory"
     telemetry = "telemetry"

@@ -8,11 +8,11 @@ import importlib
 import inspect
 from typing import Dict, List

-from llama_stack.agentic_system.api import AgenticSystem
-from llama_stack.inference.api import Inference
-from llama_stack.memory.api import Memory
-from llama_stack.safety.api import Safety
-from llama_stack.telemetry.api import Telemetry
+from llama_stack.apis.agents import Agents
+from llama_stack.apis.inference import Inference
+from llama_stack.apis.memory import Memory
+from llama_stack.apis.safety import Safety
+from llama_stack.apis.telemetry import Telemetry

 from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec
@@ -34,7 +34,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
     protocols = {
         Api.inference: Inference,
         Api.safety: Safety,
-        Api.agentic_system: AgenticSystem,
+        Api.agents: Agents,
         Api.memory: Memory,
         Api.telemetry: Telemetry,
     }
@@ -67,7 +67,7 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
     ret = {}
     for api in stack_apis():
         name = api.name.lower()
-        module = importlib.import_module(f"llama_stack.{name}.providers")
+        module = importlib.import_module(f"llama_stack.providers.registry.{name}")
         ret[api] = {
             "remote": remote_provider_spec(api),
             **{a.provider_id: a for a in module.available_providers()},
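The importlib change above is the "added providers/registry" bullet from the commit message: each API now gets its provider catalog from a llama_stack.providers.registry.<api> module exposing available_providers(), and api_providers() keys the returned specs by provider_id. A hedged sketch of the shape such a registry module takes; the real ProviderSpec fields are not shown in this diff, so the module path and constructor arguments here are illustrative only:

```python
# Hypothetical llama_stack/providers/registry/safety.py
from typing import List

from llama_stack.core.datatypes import Api, ProviderSpec  # import path assumed from the diff above

def available_providers() -> List[ProviderSpec]:
    # One spec per provider implementation of this API; field names are
    # illustrative, not the real ProviderSpec schema.
    return [
        ProviderSpec(api=Api.safety, provider_id="meta-reference"),
    ]
```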

@@ -39,7 +39,7 @@ from pydantic import BaseModel, ValidationError
 from termcolor import cprint
 from typing_extensions import Annotated

-from llama_stack.telemetry.tracing import (
+from llama_stack.providers.utils.telemetry.tracing import (
     end_trace,
     setup_logger,
     SpanStatus,

@@ -1,7 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .api import *  # noqa: F401 F403

@@ -1,7 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .api import *  # noqa: F401 F403

@@ -1,7 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .api import *  # noqa: F401 F403

@@ -6,15 +6,16 @@

 from typing import AsyncGenerator

-from fireworks.client import Fireworks
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model

-from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.inference.prepare_messages import prepare_messages
+from fireworks.client import Fireworks
+
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.providers.utils.inference.prepare_messages import prepare_messages

 from .config import FireworksImplConfig
@@ -81,7 +82,7 @@ class FireworksInferenceAdapter(Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        tools: Optional[List[ToolDefinition]] = list(),
+        tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
@@ -91,7 +92,7 @@ class FireworksInferenceAdapter(Inference):
             model=model,
             messages=messages,
             sampling_params=sampling_params,
-            tools=tools,
+            tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
             stream=stream,
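One change in these adapter diffs is a genuine bug fix rather than a rename: `tools: Optional[List[ToolDefinition]] = list()` evaluates `list()` once at function-definition time, so every call shares the same list object; switching to `= None` and normalizing with `tools or []` in the body gives each call a fresh list. The classic pitfall in miniature (toy functions, not from the repo):

```python
def bad(items=[]):  # default list created once at def time, shared by all calls
    items.append(1)
    return items

def good(items=None):  # fresh list per call
    items = items or []
    items.append(1)
    return items

print(bad())   # [1]
print(bad())   # [1, 1]  <- state leaked from the first call
print(good())  # [1]
print(good())  # [1]
```

Whether anything actually mutated the shared default in these adapters is not visible in the diff, but the None-plus-fallback pattern removes the hazard either way.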

@@ -12,10 +12,11 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model

 from ollama import AsyncClient

-from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.inference.prepare_messages import prepare_messages
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.providers.utils.inference.prepare_messages import prepare_messages

 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models
@@ -89,7 +90,7 @@ class OllamaInferenceAdapter(Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        tools: Optional[List[ToolDefinition]] = list(),
+        tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
@@ -99,7 +100,7 @@ class OllamaInferenceAdapter(Inference):
             model=model,
             messages=messages,
             sampling_params=sampling_params,
-            tools=tools,
+            tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
             stream=stream,

@@ -13,8 +13,8 @@ from huggingface_hub import HfApi, InferenceClient
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer

-from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.inference.prepare_messages import prepare_messages
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.providers.utils.inference.prepare_messages import prepare_messages

 from .config import TGIImplConfig
@@ -87,7 +87,7 @@ class TGIAdapter(Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        tools: Optional[List[ToolDefinition]] = list(),
+        tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
@@ -97,7 +97,7 @@ class TGIAdapter(Inference):
             model=model,
             messages=messages,
             sampling_params=sampling_params,
-            tools=tools,
+            tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
             stream=stream,

@@ -11,10 +11,11 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model

 from together import Together

-from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.inference.prepare_messages import prepare_messages
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.providers.utils.inference.prepare_messages import prepare_messages

 from .config import TogetherImplConfig
@@ -81,7 +82,7 @@ class TogetherInferenceAdapter(Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        tools: Optional[List[ToolDefinition]] = list(),
+        tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
@@ -92,7 +93,7 @@ class TogetherInferenceAdapter(Inference):
             model=model,
             messages=messages,
             sampling_params=sampling_params,
-            tools=tools,
+            tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
             stream=stream,

@@ -12,10 +12,13 @@ from urllib.parse import urlparse

 import chromadb
 from numpy.typing import NDArray

-from llama_stack.memory.api import *  # noqa: F403
-from llama_stack.memory.common.vector_store import BankWithIndex, EmbeddingIndex
+from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.providers.utils.memory.vector_store import (
+    BankWithIndex,
+    EmbeddingIndex,
+)

 class ChromaIndex(EmbeddingIndex):

@@ -13,10 +13,10 @@ from numpy.typing import NDArray
 from psycopg2 import sql
 from psycopg2.extras import execute_values, Json
 from pydantic import BaseModel

-from llama_stack.memory.api import *  # noqa: F403
-from llama_stack.memory.common.vector_store import (
+from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.providers.utils.memory.vector_store import (
     ALL_MINILM_L6_V2_DIMENSION,
     BankWithIndex,
     EmbeddingIndex,

@@ -14,13 +14,13 @@ from .config import MetaReferenceImplConfig

 async def get_provider_impl(
     config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
 ):
-    from .agentic_system import MetaReferenceAgenticSystemImpl
+    from .agents import MetaReferenceAgentsImpl

     assert isinstance(
         config, MetaReferenceImplConfig
     ), f"Unexpected config type: {type(config)}"

-    impl = MetaReferenceAgenticSystemImpl(
+    impl = MetaReferenceAgentsImpl(
         config,
         deps[Api.inference],
         deps[Api.memory],

@@ -20,10 +20,10 @@ import httpx

 from termcolor import cprint

-from llama_stack.agentic_system.api import *  # noqa: F403
-from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.memory.api import *  # noqa: F403
-from llama_stack.safety.api import *  # noqa: F403
+from llama_stack.apis.agents import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.apis.safety import *  # noqa: F403

 from llama_stack.tools.base import BaseTool
 from llama_stack.tools.builtin import (
@@ -122,7 +122,7 @@ class ChatAgent(ShieldRunnerMixin):
         return session

     async def create_and_execute_turn(
-        self, request: AgenticSystemTurnCreateRequest
+        self, request: AgentTurnCreateRequest
     ) -> AsyncGenerator:
         assert (
             request.session_id in self.sessions
@@ -141,9 +141,9 @@ class ChatAgent(ShieldRunnerMixin):
         turn_id = str(uuid.uuid4())
         start_time = datetime.now()
-        yield AgenticSystemTurnResponseStreamChunk(
-            event=AgenticSystemTurnResponseEvent(
-                payload=AgenticSystemTurnResponseTurnStartPayload(
+        yield AgentTurnResponseStreamChunk(
+            event=AgentTurnResponseEvent(
+                payload=AgentTurnResponseTurnStartPayload(
                     turn_id=turn_id,
                 )
             )
@@ -169,12 +169,12 @@ class ChatAgent(ShieldRunnerMixin):
                 continue

             assert isinstance(
-                chunk, AgenticSystemTurnResponseStreamChunk
+                chunk, AgentTurnResponseStreamChunk
             ), f"Unexpected type {type(chunk)}"
             event = chunk.event
             if (
                 event.payload.event_type
-                == AgenticSystemTurnResponseEventType.step_complete.value
+                == AgentTurnResponseEventType.step_complete.value
             ):
                 steps.append(event.payload.step_details)
@@ -193,9 +193,9 @@ class ChatAgent(ShieldRunnerMixin):
         )
         session.turns.append(turn)

-        chunk = AgenticSystemTurnResponseStreamChunk(
-            event=AgenticSystemTurnResponseEvent(
-                payload=AgenticSystemTurnResponseTurnCompletePayload(
+        chunk = AgentTurnResponseStreamChunk(
+            event=AgentTurnResponseEvent(
+                payload=AgentTurnResponseTurnCompletePayload(
                     turn=turn,
                 )
             )
@@ -261,9 +261,9 @@ class ChatAgent(ShieldRunnerMixin):
         step_id = str(uuid.uuid4())
         try:
-            yield AgenticSystemTurnResponseStreamChunk(
-                event=AgenticSystemTurnResponseEvent(
-                    payload=AgenticSystemTurnResponseStepStartPayload(
+            yield AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseStepStartPayload(
                         step_type=StepType.shield_call.value,
                         step_id=step_id,
                         metadata=dict(touchpoint=touchpoint),
@@ -273,9 +273,9 @@ class ChatAgent(ShieldRunnerMixin):
             await self.run_shields(messages, shields)

         except SafetyException as e:
-            yield AgenticSystemTurnResponseStreamChunk(
-                event=AgenticSystemTurnResponseEvent(
-                    payload=AgenticSystemTurnResponseStepCompletePayload(
+            yield AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseStepCompletePayload(
                         step_type=StepType.shield_call.value,
                         step_details=ShieldCallStep(
                             step_id=step_id,
@@ -292,9 +292,9 @@ class ChatAgent(ShieldRunnerMixin):
             )
             yield False

-        yield AgenticSystemTurnResponseStreamChunk(
-            event=AgenticSystemTurnResponseEvent(
-                payload=AgenticSystemTurnResponseStepCompletePayload(
+        yield AgentTurnResponseStreamChunk(
+            event=AgentTurnResponseEvent(
+                payload=AgentTurnResponseStepCompletePayload(
                     step_type=StepType.shield_call.value,
                     step_details=ShieldCallStep(
                         step_id=step_id,
@@ -325,9 +325,9 @@ class ChatAgent(ShieldRunnerMixin):
         )
         if need_rag_context:
             step_id = str(uuid.uuid4())
-            yield AgenticSystemTurnResponseStreamChunk(
-                event=AgenticSystemTurnResponseEvent(
-                    payload=AgenticSystemTurnResponseStepStartPayload(
+            yield AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseStepStartPayload(
                         step_type=StepType.memory_retrieval.value,
                         step_id=step_id,
                     )
@@ -341,9 +341,9 @@ class ChatAgent(ShieldRunnerMixin):
             )

             step_id = str(uuid.uuid4())
-            yield AgenticSystemTurnResponseStreamChunk(
-                event=AgenticSystemTurnResponseEvent(
-                    payload=AgenticSystemTurnResponseStepCompletePayload(
+            yield AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseStepCompletePayload(
                         step_type=StepType.memory_retrieval.value,
                         step_id=step_id,
                         step_details=MemoryRetrievalStep(
@@ -360,7 +360,7 @@ class ChatAgent(ShieldRunnerMixin):
             last_message = input_messages[-1]
             last_message.context = "\n".join(rag_context)

-        elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools:
+        elif attachments and AgentTool.code_interpreter.value in enabled_tools:
             urls = [a.content for a in attachments if isinstance(a.content, URL)]
             msg = await attachment_message(self.tempdir, urls)
             input_messages.append(msg)
@ -379,9 +379,9 @@ class ChatAgent(ShieldRunnerMixin):
cprint(f"{str(msg)}", color=color) cprint(f"{str(msg)}", color=color)
step_id = str(uuid.uuid4()) step_id = str(uuid.uuid4())
yield AgenticSystemTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgenticSystemTurnResponseStepStartPayload( payload=AgentTurnResponseStepStartPayload(
step_type=StepType.inference.value, step_type=StepType.inference.value,
step_id=step_id, step_id=step_id,
) )
@ -412,9 +412,9 @@ class ChatAgent(ShieldRunnerMixin):
tool_calls.append(delta.content) tool_calls.append(delta.content)
if stream: if stream:
yield AgenticSystemTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgenticSystemTurnResponseStepProgressPayload( payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value, step_type=StepType.inference.value,
step_id=step_id, step_id=step_id,
model_response_text_delta="", model_response_text_delta="",
@ -426,9 +426,9 @@ class ChatAgent(ShieldRunnerMixin):
elif isinstance(delta, str): elif isinstance(delta, str):
content += delta content += delta
if stream and event.stop_reason is None: if stream and event.stop_reason is None:
yield AgenticSystemTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgenticSystemTurnResponseStepProgressPayload( payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value, step_type=StepType.inference.value,
step_id=step_id, step_id=step_id,
model_response_text_delta=event.delta, model_response_text_delta=event.delta,
@ -448,9 +448,9 @@ class ChatAgent(ShieldRunnerMixin):
tool_calls=tool_calls, tool_calls=tool_calls,
) )
yield AgenticSystemTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgenticSystemTurnResponseStepCompletePayload( payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.inference.value, step_type=StepType.inference.value,
step_id=step_id, step_id=step_id,
step_details=InferenceStep( step_details=InferenceStep(
@ -498,17 +498,17 @@ class ChatAgent(ShieldRunnerMixin):
return return
step_id = str(uuid.uuid4()) step_id = str(uuid.uuid4())
yield AgenticSystemTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgenticSystemTurnResponseStepStartPayload( payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value, step_type=StepType.tool_execution.value,
step_id=step_id, step_id=step_id,
) )
) )
) )
yield AgenticSystemTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgenticSystemTurnResponseStepProgressPayload( payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value, step_type=StepType.tool_execution.value,
step_id=step_id, step_id=step_id,
tool_call=tool_call, tool_call=tool_call,
@ -525,9 +525,9 @@ class ChatAgent(ShieldRunnerMixin):
), "Currently not supporting multiple messages" ), "Currently not supporting multiple messages"
result_message = result_messages[0] result_message = result_messages[0]
yield AgenticSystemTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgenticSystemTurnResponseStepCompletePayload( payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value, step_type=StepType.tool_execution.value,
step_details=ToolExecutionStep( step_details=ToolExecutionStep(
step_id=step_id, step_id=step_id,
@ -547,9 +547,9 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: add tool-input touchpoint and a "start" event for this step also # TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially # but that needs a lot more refactoring of Tool code potentially
yield AgenticSystemTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgenticSystemTurnResponseStepCompletePayload( payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value, step_type=StepType.shield_call.value,
step_details=ShieldCallStep( step_details=ShieldCallStep(
step_id=str(uuid.uuid4()), step_id=str(uuid.uuid4()),
@ -566,9 +566,9 @@ class ChatAgent(ShieldRunnerMixin):
) )
except SafetyException as e: except SafetyException as e:
yield AgenticSystemTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgenticSystemTurnResponseStepCompletePayload( payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value, step_type=StepType.shield_call.value,
step_details=ShieldCallStep( step_details=ShieldCallStep(
step_id=str(uuid.uuid4()), step_id=str(uuid.uuid4()),
@@ -616,18 +616,18 @@ class ChatAgent(ShieldRunnerMixin):
         enabled_tools = set(t.type for t in self.agent_config.tools)
         if attachments:
             if (
-                AgenticSystemTool.code_interpreter.value in enabled_tools
+                AgentTool.code_interpreter.value in enabled_tools
                 and self.agent_config.tool_choice == ToolChoice.required
             ):
                 return False
             else:
                 return True

-        return AgenticSystemTool.memory.value in enabled_tools
+        return AgentTool.memory.value in enabled_tools

     def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
         for t in self.agent_config.tools:
-            if t.type == AgenticSystemTool.memory.value:
+            if t.type == AgentTool.memory.value:
                 return t
         return None
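
The `AgenticSystemTurnResponse*` event classes above become `AgentTurnResponse*` with their payload fields unchanged. A minimal construction sketch, assuming the renamed classes are exported from `llama_stack.apis.agents` (consistent with the import changes elsewhere in this commit):

```
# Sketch only, not code from this commit: build a step-start stream chunk
# with the renamed Agent* event types.
import uuid

from llama_stack.apis.agents import (
    AgentTurnResponseEvent,
    AgentTurnResponseStepStartPayload,
    AgentTurnResponseStreamChunk,
    StepType,
)

chunk = AgentTurnResponseStreamChunk(
    event=AgentTurnResponseEvent(
        payload=AgentTurnResponseStepStartPayload(
            step_type=StepType.tool_execution.value,  # same StepType enum as before
            step_id=str(uuid.uuid4()),
        )
    )
)
```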


@@ -10,10 +10,10 @@ import tempfile
 import uuid
 from typing import AsyncGenerator

-from llama_stack.inference.api import Inference
-from llama_stack.memory.api import Memory
-from llama_stack.safety.api import Safety
-from llama_stack.agentic_system.api import *  # noqa: F403
+from llama_stack.apis.inference import Inference
+from llama_stack.apis.memory import Memory
+from llama_stack.apis.safety import Safety
+from llama_stack.apis.agents import *  # noqa: F403

 from llama_stack.tools.builtin import (
     CodeInterpreterTool,
     PhotogenTool,
@@ -33,7 +33,7 @@ logger.setLevel(logging.INFO)
 AGENT_INSTANCES_BY_ID = {}

-class MetaReferenceAgenticSystemImpl(AgenticSystem):
+class MetaReferenceAgentsImpl(Agents):
     def __init__(
         self,
         config: MetaReferenceImplConfig,
@@ -49,10 +49,10 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
     async def initialize(self) -> None:
         pass

-    async def create_agentic_system(
+    async def create_agent(
         self,
         agent_config: AgentConfig,
-    ) -> AgenticSystemCreateResponse:
+    ) -> AgentCreateResponse:
         agent_id = str(uuid.uuid4())

         builtin_tools = []
@@ -95,24 +95,24 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
             builtin_tools=builtin_tools,
         )

-        return AgenticSystemCreateResponse(
+        return AgentCreateResponse(
             agent_id=agent_id,
         )

-    async def create_agentic_system_session(
+    async def create_agent_session(
         self,
         agent_id: str,
         session_name: str,
-    ) -> AgenticSystemSessionCreateResponse:
+    ) -> AgentSessionCreateResponse:
         assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
         agent = AGENT_INSTANCES_BY_ID[agent_id]

         session = agent.create_session(session_name)
-        return AgenticSystemSessionCreateResponse(
+        return AgentSessionCreateResponse(
             session_id=session.session_id,
         )

-    async def create_agentic_system_turn(
+    async def create_agent_turn(
         self,
         agent_id: str,
         session_id: str,
@@ -126,7 +126,7 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
         stream: Optional[bool] = False,
     ) -> AsyncGenerator:
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
-        request = AgenticSystemTurnCreateRequest(
+        request = AgentTurnCreateRequest(
             agent_id=agent_id,
             session_id=session_id,
             messages=messages,
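
On the caller side, `create_agentic_system`, `create_agentic_system_session`, and `create_agentic_system_turn` become `create_agent`, `create_agent_session`, and `create_agent_turn`. A migration sketch; the method and response names come from this diff, while `impl` and `agent_config` are assumed to be an initialized `MetaReferenceAgentsImpl` and a valid `AgentConfig`:

```
# Sketch only: surrounding setup (impl, agent_config) is assumed, not shown.
async def set_up_agent(impl, agent_config):
    create_resp = await impl.create_agent(agent_config=agent_config)
    session_resp = await impl.create_agent_session(
        agent_id=create_resp.agent_id,
        session_name="demo-session",
    )
    return create_resp.agent_id, session_resp.session_id
```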


@@ -10,14 +10,14 @@ from jinja2 import Template
 from llama_models.llama3.api import *  # noqa: F403

-from llama_stack.agentic_system.api import (
+from llama_stack.apis.agents import (
     DefaultMemoryQueryGeneratorConfig,
     LLMMemoryQueryGeneratorConfig,
     MemoryQueryGenerator,
     MemoryQueryGeneratorConfig,
 )
 from termcolor import cprint  # noqa: F401

-from llama_stack.inference.api import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403

 async def generate_rag_query(


@@ -7,15 +7,15 @@
 from typing import List

 from llama_models.llama3.api.datatypes import Message, Role, UserMessage
-from termcolor import cprint

-from llama_stack.safety.api import (
+from llama_stack.apis.safety import (
     OnViolationAction,
     RunShieldRequest,
     Safety,
     ShieldDefinition,
     ShieldResponse,
 )
+from termcolor import cprint

 class SafetyException(Exception):  # noqa: N818


@@ -11,9 +11,9 @@ from llama_models.datatypes import ModelFamily
 from llama_models.schema_utils import json_schema_type
 from llama_models.sku_list import all_registered_models, resolve_model

-from pydantic import BaseModel, Field, field_validator
-from llama_stack.inference.api import QuantizationConfig
+from llama_stack.apis.inference import QuantizationConfig
+from pydantic import BaseModel, Field, field_validator

 @json_schema_type


@@ -28,10 +28,10 @@ from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.sku_list import resolve_model

-from termcolor import cprint
+from llama_stack.apis.inference import QuantizationType
 from llama_stack.common.model_utils import model_local_dir
-from llama_stack.inference.api import QuantizationType
+from termcolor import cprint

 from .config import MetaReferenceImplConfig


@@ -11,7 +11,7 @@ from typing import AsyncIterator, Union
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.sku_list import resolve_model

-from llama_stack.inference.api import (
+from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseEvent,
@@ -21,13 +21,13 @@ from llama_stack.inference.api import (
     ToolCallDelta,
     ToolCallParseStatus,
 )
-from llama_stack.inference.prepare_messages import prepare_messages
+from llama_stack.providers.utils.inference.prepare_messages import prepare_messages

 from .config import MetaReferenceImplConfig
 from .model_parallel import LlamaModelParallelGenerator

 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.inference.api import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403

 # there's a single model parallel process running serving the model. for now,
 # we don't support multiple concurrent requests to this process.
@@ -57,7 +57,7 @@ class MetaReferenceInferenceImpl(Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        tools: Optional[List[ToolDefinition]] = list(),
+        tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
@@ -70,7 +70,7 @@ class MetaReferenceInferenceImpl(Inference):
             model=model,
             messages=messages,
             sampling_params=sampling_params,
-            tools=tools,
+            tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
             stream=stream,
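
The signature change from `tools: Optional[List[ToolDefinition]] = list()` to `= None` (normalized back with `tools or []` at the request construction site) avoids Python's shared-mutable-default pitfall: a default like `list()` is evaluated once when the function is defined, so every call that omits the argument sees the same list object. A standalone illustration, unrelated to llama_stack code:

```
def broken(items=[]):  # one shared list, created at definition time
    items.append(1)
    return items

def fixed(items=None):
    items = items or []  # fresh list on every call
    items.append(1)
    return items

assert broken() == [1]
assert broken() == [1, 1]  # state leaked from the previous call
assert fixed() == [1]
assert fixed() == [1]
```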


@@ -14,9 +14,9 @@ import torch
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
 from llama_models.llama3.api.model import Transformer, TransformerBlock

-from llama_stack.inference.api import QuantizationType
-from llama_stack.inference.api.config import (
+from llama_stack.apis.inference import QuantizationType
+from llama_stack.apis.inference.config import (
     CheckpointQuantizationFormat,
     MetaReferenceImplConfig,
 )


@@ -15,13 +15,14 @@ from numpy.typing import NDArray
 from llama_models.llama3.api.datatypes import *  # noqa: F403

-from llama_stack.memory.api import *  # noqa: F403
-from llama_stack.memory.common.vector_store import (
+from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.providers.utils.memory.vector_store import (
     ALL_MINILM_L6_V2_DIMENSION,
     BankWithIndex,
     EmbeddingIndex,
 )
-from llama_stack.telemetry import tracing
+from llama_stack.providers.utils.telemetry import tracing

 from .config import FaissImplConfig

 logger = logging.getLogger(__name__)
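
Shared provider helpers also move out of the API packages and under `llama_stack.providers.utils`. Based only on the paths visible in this diff, downstream imports would now look like this sketch:

```
# New homes for shared provider utilities, per this commit's path changes.
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
from llama_stack.providers.utils.memory.vector_store import BankWithIndex
from llama_stack.providers.utils.telemetry import tracing
```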


@@ -9,7 +9,7 @@ import asyncio
 from llama_models.sku_list import resolve_model

 from llama_stack.common.model_utils import model_local_dir
-from llama_stack.safety.api import *  # noqa
+from llama_stack.apis.safety import *  # noqa

 from .config import SafetyConfig
 from .shields import (


@@ -8,7 +8,7 @@ from abc import ABC, abstractmethod
 from typing import List

 from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
-from llama_stack.safety.api import *  # noqa: F403
+from llama_stack.apis.safety import *  # noqa: F403

 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"


@@ -8,7 +8,7 @@ from codeshield.cs import CodeShield
 from termcolor import cprint

 from .base import ShieldResponse, TextShield
-from llama_stack.safety.api import *  # noqa: F403
+from llama_stack.apis.safety import *  # noqa: F403

 class CodeScannerShield(TextShield):


@@ -14,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, Role
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
-from llama_stack.safety.api import *  # noqa: F403
+from llama_stack.apis.safety import *  # noqa: F403

 SAFE_RESPONSE = "safe"

 _INSTANCE = None


@@ -14,7 +14,7 @@ from termcolor import cprint
 from transformers import AutoModelForSequenceClassification, AutoTokenizer

 from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
-from llama_stack.safety.api import *  # noqa: F403
+from llama_stack.apis.safety import *  # noqa: F403

 class PromptGuardShield(TextShield):


@@ -6,7 +6,7 @@
 from typing import Optional

-from llama_stack.telemetry.api import *  # noqa: F403
+from llama_stack.apis.telemetry import *  # noqa: F403

 from .config import ConsoleConfig


@@ -12,7 +12,7 @@ from llama_stack.core.datatypes import Api, InlineProviderSpec, ProviderSpec
 def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
-            api=Api.agentic_system,
+            api=Api.agents,
             provider_id="meta-reference",
             pip_packages=[
                 "codeshield",
@@ -23,8 +23,8 @@ def available_providers() -> List[ProviderSpec]:
                 "torch",
                 "transformers",
             ],
-            module="llama_stack.agentic_system.meta_reference",
-            config_class="llama_stack.agentic_system.meta_reference.MetaReferenceImplConfig",
+            module="llama_stack.providers.impls.meta_reference.agents",
+            config_class="llama_stack.providers.impls.meta_reference.agents.MetaReferenceImplConfig",
             api_dependencies=[
                 Api.inference,
                 Api.safety,
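
The registry entry's `module` and `config_class` now point into the `llama_stack.providers.impls` tree. A generic sketch of how dotted-string entries like these are typically resolved at runtime; `resolve_config_class` is a hypothetical helper for illustration, not necessarily this repo's loader:

```
import importlib

def resolve_config_class(dotted_path: str):
    """Import `pkg.mod.ClassName` given as a dotted string."""
    module_name, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

# e.g., for the entry above:
# config_cls = resolve_config_class(
#     "llama_stack.providers.impls.meta_reference.agents.MetaReferenceImplConfig"
# )
```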
