Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-30 07:39:38 +00:00)
Codemod from llama_toolchain -> llama_stack

- added providers/registry
- cleaned up api/ subdirectories and moved impls away
- restructured api/api.py
- from llama_stack.apis.<api> import foo should work now
- update imports to do llama_stack.apis.<api>
- update many other imports
- added __init__, fixed some registry imports
- updated registry imports
- create_agentic_system -> create_agent
- AgenticSystem -> Agent
This commit is contained in:
parent 2cf731faea
commit 76b354a081
128 changed files with 381 additions and 376 deletions
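The rename is mechanical: every API moves from `llama_stack.<api>.api` (and the old `agentic_system` package) to `llama_stack.apis.<api>`. A before/after sketch assembled from the hunks below; the grouping is illustrative, but the import paths themselves are taken verbatim from the diff:

```python
# Before this commit (llama_toolchain-era layout):
# from llama_stack.inference.api import *          # noqa: F403
# from llama_stack.safety.api import *             # noqa: F403
# from llama_stack.agentic_system.api import AgenticSystem

# After this commit:
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.apis.safety import *  # noqa: F403
from llama_stack.apis.agents import Agents  # AgenticSystem -> Agents
```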
@@ -482,7 +482,7 @@ Once the server is setup, we can test it with a client to see the example output
 cd /path/to/llama-stack
 conda activate <env>  # any environment containing the llama-toolchain pip package will work

-python -m llama_stack.inference.client localhost 5000
+python -m llama_stack.apis.inference.client localhost 5000
 ```

 This will run the chat completion client and query the distribution’s /inference/chat_completion API.
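For context, the client module being renamed here drives the distribution's /inference/chat_completion route. A rough httpx sketch of the kind of request it sends — the payload fields and model name below are assumptions for illustration, not the exact wire schema (that is defined by ChatCompletionRequest in llama_stack.apis.inference):

```python
import httpx

# Hypothetical one-shot call against a locally running distribution.
response = httpx.post(
    "http://localhost:5000/inference/chat_completion",
    json={
        "model": "Meta-Llama3.1-8B-Instruct",  # assumed model identifier
        "messages": [{"role": "user", "content": "hello world"}],
        "stream": False,
    },
    headers={"Content-Type": "application/json"},
    timeout=30,
)
response.raise_for_status()
print(response.json())
```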
@@ -296,7 +296,7 @@ Once the server is setup, we can test it with a client to see the example output
 cd /path/to/llama-stack
 conda activate <env>  # any environment containing the llama-toolchain pip package will work

-python -m llama_stack.inference.client localhost 5000
+python -m llama_stack.apis.inference.client localhost 5000
 ```

 This will run the chat completion client and query the distribution’s /inference/chat_completion API.
@@ -314,7 +314,7 @@
 Similarly you can test safety (if you configured llama-guard and/or prompt-guard shields) by:

 ```
-python -m llama_stack.safety.client localhost 5000
+python -m llama_stack.apis.safety.client localhost 5000
 ```

 You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/sdk_examples) repo.
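The safety client follows the same shape against the safety API. A minimal sketch, assuming a shield-run style endpoint; the route and field names here are assumptions, the real request types live in llama_stack.apis.safety:

```python
import httpx

# Hypothetical request asking the configured shields to screen a message.
response = httpx.post(
    "http://localhost:5000/safety/run_shields",  # assumed route
    json={
        "messages": [{"role": "user", "content": "hello world"}],
        "shields": [{"shield_type": "llama_guard"}],  # assumed shape
    },
    timeout=30,
)
response.raise_for_status()
print(response.json())
```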
llama_stack/apis/agents/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .agents import *  # noqa: F401 F403
@@ -15,9 +15,9 @@ from typing_extensions import Annotated

 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.common.deployment_types import *  # noqa: F403
-from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.safety.api import *  # noqa: F403
-from llama_stack.memory.api import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.safety import *  # noqa: F403
+from llama_stack.apis.memory import *  # noqa: F403


 @json_schema_type
@@ -26,7 +26,7 @@ class Attachment(BaseModel):
     mime_type: str


-class AgenticSystemTool(Enum):
+class AgentTool(Enum):
     brave_search = "brave_search"
     wolfram_alpha = "wolfram_alpha"
     photogen = "photogen"
@@ -50,41 +50,33 @@ class SearchEngineType(Enum):
 class SearchToolDefinition(ToolDefinitionCommon):
     # NOTE: brave_search is just a placeholder since model always uses
     # brave_search as tool call name
-    type: Literal[AgenticSystemTool.brave_search.value] = (
-        AgenticSystemTool.brave_search.value
-    )
+    type: Literal[AgentTool.brave_search.value] = AgentTool.brave_search.value
     engine: SearchEngineType = SearchEngineType.brave
     remote_execution: Optional[RestAPIExecutionConfig] = None


 @json_schema_type
 class WolframAlphaToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgenticSystemTool.wolfram_alpha.value] = (
-        AgenticSystemTool.wolfram_alpha.value
-    )
+    type: Literal[AgentTool.wolfram_alpha.value] = AgentTool.wolfram_alpha.value
     remote_execution: Optional[RestAPIExecutionConfig] = None


 @json_schema_type
 class PhotogenToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgenticSystemTool.photogen.value] = AgenticSystemTool.photogen.value
+    type: Literal[AgentTool.photogen.value] = AgentTool.photogen.value
     remote_execution: Optional[RestAPIExecutionConfig] = None


 @json_schema_type
 class CodeInterpreterToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgenticSystemTool.code_interpreter.value] = (
-        AgenticSystemTool.code_interpreter.value
-    )
+    type: Literal[AgentTool.code_interpreter.value] = AgentTool.code_interpreter.value
     enable_inline_code_execution: bool = True
     remote_execution: Optional[RestAPIExecutionConfig] = None


 @json_schema_type
 class FunctionCallToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgenticSystemTool.function_call.value] = (
-        AgenticSystemTool.function_call.value
-    )
+    type: Literal[AgentTool.function_call.value] = AgentTool.function_call.value
     function_name: str
     description: str
     parameters: Dict[str, ToolParamDefinition]
@@ -95,30 +87,30 @@ class _MemoryBankConfigCommon(BaseModel):
     bank_id: str


-class AgenticSystemVectorMemoryBankConfig(_MemoryBankConfigCommon):
+class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon):
     type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value


-class AgenticSystemKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
+class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
     type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
     keys: List[str]  # what keys to focus on


-class AgenticSystemKeywordMemoryBankConfig(_MemoryBankConfigCommon):
+class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon):
     type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value


-class AgenticSystemGraphMemoryBankConfig(_MemoryBankConfigCommon):
+class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon):
     type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
     entities: List[str]  # what entities to focus on


 MemoryBankConfig = Annotated[
     Union[
-        AgenticSystemVectorMemoryBankConfig,
-        AgenticSystemKeyValueMemoryBankConfig,
-        AgenticSystemKeywordMemoryBankConfig,
-        AgenticSystemGraphMemoryBankConfig,
+        AgentVectorMemoryBankConfig,
+        AgentKeyValueMemoryBankConfig,
+        AgentKeywordMemoryBankConfig,
+        AgentGraphMemoryBankConfig,
     ],
     Field(discriminator="type"),
 ]
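The Literal-typed `type` defaults in these models are what make `Field(discriminator="type")` work: Pydantic routes each dict to the right subclass purely on that field. A self-contained sketch of the pattern with made-up class names (assuming Pydantic v2, which the ConfigDict usage elsewhere in this file implies):

```python
from typing import List, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter
from typing_extensions import Annotated


class VectorConfig(BaseModel):
    type: Literal["vector"] = "vector"
    bank_id: str


class KeyValueConfig(BaseModel):
    type: Literal["keyvalue"] = "keyvalue"
    bank_id: str
    keys: List[str]


# The discriminator tells Pydantic which member of the Union to build,
# based solely on the value of the "type" field.
Config = Annotated[Union[VectorConfig, KeyValueConfig], Field(discriminator="type")]

cfg = TypeAdapter(Config).validate_python(
    {"type": "keyvalue", "bank_id": "b1", "keys": ["k"]}
)
assert isinstance(cfg, KeyValueConfig)
```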
@@ -158,7 +150,7 @@ MemoryQueryGeneratorConfig = Annotated[


 class MemoryToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
+    type: Literal[AgentTool.memory.value] = AgentTool.memory.value
     memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
     # This config defines how a query is generated using the messages
     # for memory bank retrieval.
@@ -169,7 +161,7 @@ class MemoryToolDefinition(ToolDefinitionCommon):
     max_chunks: int = 10


-AgenticSystemToolDefinition = Annotated[
+AgentToolDefinition = Annotated[
     Union[
         SearchToolDefinition,
         WolframAlphaToolDefinition,
@@ -275,7 +267,7 @@ class AgentConfigCommon(BaseModel):
     input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
     output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)

-    tools: Optional[List[AgenticSystemToolDefinition]] = Field(default_factory=list)
+    tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
     tool_prompt_format: Optional[ToolPromptFormat] = Field(
         default=ToolPromptFormat.json
@@ -292,7 +284,7 @@ class AgentConfigOverridablePerTurn(AgentConfigCommon):
     instructions: Optional[str] = None


-class AgenticSystemTurnResponseEventType(Enum):
+class AgentTurnResponseEventType(Enum):
     step_start = "step_start"
     step_complete = "step_complete"
     step_progress = "step_progress"
@@ -302,9 +294,9 @@ class AgenticSystemTurnResponseEventType(Enum):


 @json_schema_type
-class AgenticSystemTurnResponseStepStartPayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
-        AgenticSystemTurnResponseEventType.step_start.value
+class AgentTurnResponseStepStartPayload(BaseModel):
+    event_type: Literal[AgentTurnResponseEventType.step_start.value] = (
+        AgentTurnResponseEventType.step_start.value
     )
     step_type: StepType
     step_id: str
@@ -312,20 +304,20 @@ class AgenticSystemTurnResponseStepStartPayload(BaseModel):


 @json_schema_type
-class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
-        AgenticSystemTurnResponseEventType.step_complete.value
+class AgentTurnResponseStepCompletePayload(BaseModel):
+    event_type: Literal[AgentTurnResponseEventType.step_complete.value] = (
+        AgentTurnResponseEventType.step_complete.value
     )
     step_type: StepType
     step_details: Step


 @json_schema_type
-class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
+class AgentTurnResponseStepProgressPayload(BaseModel):
     model_config = ConfigDict(protected_namespaces=())

-    event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
-        AgenticSystemTurnResponseEventType.step_progress.value
+    event_type: Literal[AgentTurnResponseEventType.step_progress.value] = (
+        AgentTurnResponseEventType.step_progress.value
     )
     step_type: StepType
     step_id: str
@@ -336,49 +328,49 @@ class AgenticSystemTurnResponseStepProgressPayload(BaseModel):


 @json_schema_type
-class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
-        AgenticSystemTurnResponseEventType.turn_start.value
+class AgentTurnResponseTurnStartPayload(BaseModel):
+    event_type: Literal[AgentTurnResponseEventType.turn_start.value] = (
+        AgentTurnResponseEventType.turn_start.value
     )
     turn_id: str


 @json_schema_type
-class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
-        AgenticSystemTurnResponseEventType.turn_complete.value
+class AgentTurnResponseTurnCompletePayload(BaseModel):
+    event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = (
+        AgentTurnResponseEventType.turn_complete.value
     )
     turn: Turn


 @json_schema_type
-class AgenticSystemTurnResponseEvent(BaseModel):
+class AgentTurnResponseEvent(BaseModel):
     """Streamed agent execution response."""

     payload: Annotated[
         Union[
-            AgenticSystemTurnResponseStepStartPayload,
-            AgenticSystemTurnResponseStepProgressPayload,
-            AgenticSystemTurnResponseStepCompletePayload,
-            AgenticSystemTurnResponseTurnStartPayload,
-            AgenticSystemTurnResponseTurnCompletePayload,
+            AgentTurnResponseStepStartPayload,
+            AgentTurnResponseStepProgressPayload,
+            AgentTurnResponseStepCompletePayload,
+            AgentTurnResponseTurnStartPayload,
+            AgentTurnResponseTurnCompletePayload,
         ],
         Field(discriminator="event_type"),
     ]


 @json_schema_type
-class AgenticSystemCreateResponse(BaseModel):
+class AgentCreateResponse(BaseModel):
     agent_id: str


 @json_schema_type
-class AgenticSystemSessionCreateResponse(BaseModel):
+class AgentSessionCreateResponse(BaseModel):
     session_id: str


 @json_schema_type
-class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn):
+class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
     agent_id: str
     session_id: str

@@ -397,24 +389,24 @@ class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn):


 @json_schema_type
-class AgenticSystemTurnResponseStreamChunk(BaseModel):
-    event: AgenticSystemTurnResponseEvent
+class AgentTurnResponseStreamChunk(BaseModel):
+    event: AgentTurnResponseEvent


 @json_schema_type
-class AgenticSystemStepResponse(BaseModel):
+class AgentStepResponse(BaseModel):
     step: Step


-class AgenticSystem(Protocol):
-    @webmethod(route="/agentic_system/create")
-    async def create_agentic_system(
+class Agents(Protocol):
+    @webmethod(route="/agents/create")
+    async def create_agent(
         self,
         agent_config: AgentConfig,
-    ) -> AgenticSystemCreateResponse: ...
+    ) -> AgentCreateResponse: ...

-    @webmethod(route="/agentic_system/turn/create")
-    async def create_agentic_system_turn(
+    @webmethod(route="/agents/turn/create")
+    async def create_agent_turn(
         self,
         agent_id: str,
         session_id: str,
@@ -426,42 +418,40 @@ class AgenticSystem(Protocol):
         ],
         attachments: Optional[List[Attachment]] = None,
         stream: Optional[bool] = False,
-    ) -> AgenticSystemTurnResponseStreamChunk: ...
+    ) -> AgentTurnResponseStreamChunk: ...

-    @webmethod(route="/agentic_system/turn/get")
-    async def get_agentic_system_turn(
+    @webmethod(route="/agents/turn/get")
+    async def get_agents_turn(
         self,
         agent_id: str,
         turn_id: str,
     ) -> Turn: ...

-    @webmethod(route="/agentic_system/step/get")
-    async def get_agentic_system_step(
+    @webmethod(route="/agents/step/get")
+    async def get_agents_step(
         self, agent_id: str, turn_id: str, step_id: str
-    ) -> AgenticSystemStepResponse: ...
+    ) -> AgentStepResponse: ...

-    @webmethod(route="/agentic_system/session/create")
-    async def create_agentic_system_session(
+    @webmethod(route="/agents/session/create")
+    async def create_agent_session(
         self,
         agent_id: str,
         session_name: str,
-    ) -> AgenticSystemSessionCreateResponse: ...
+    ) -> AgentSessionCreateResponse: ...

-    @webmethod(route="/agentic_system/session/get")
-    async def get_agentic_system_session(
+    @webmethod(route="/agents/session/get")
+    async def get_agents_session(
         self,
         agent_id: str,
         session_id: str,
         turn_ids: Optional[List[str]] = None,
     ) -> Session: ...

-    @webmethod(route="/agentic_system/session/delete")
-    async def delete_agentic_system_session(
-        self, agent_id: str, session_id: str
-    ) -> None: ...
+    @webmethod(route="/agents/session/delete")
+    async def delete_agents_session(self, agent_id: str, session_id: str) -> None: ...

-    @webmethod(route="/agentic_system/delete")
-    async def delete_agentic_system(
+    @webmethod(route="/agents/delete")
+    async def delete_agents(
         self,
         agent_id: str,
     ) -> None: ...
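Put together, the renamed surface reads like this from a caller's side. A sketch only: it leans on the AgentsClient defined in the next hunk, and the AgentConfig fields shown are placeholders, not the full schema:

```python
import asyncio


async def demo() -> None:
    api = AgentsClient("http://localhost:5000")  # defined in the client hunk below

    # AgentConfig field names here are assumed for illustration.
    agent_config = AgentConfig(
        model="Meta-Llama3.1-8B-Instruct",
        instructions="You are a helpful assistant",
    )

    create = await api.create_agent(agent_config)  # POST /agents/create
    session = await api.create_agent_session(      # POST /agents/session/create
        agent_id=create.agent_id,
        session_name="demo_session",
    )

    # POST /agents/turn/create streams AgentTurnResponseStreamChunk events.
    async for chunk in api.create_agent_turn(
        AgentTurnCreateRequest(
            agent_id=create.agent_id,
            session_id=session.session_id,
            messages=[UserMessage(content="hello")],
            stream=True,
        )
    ):
        print(chunk.event.payload.event_type)


# asyncio.run(demo())  # assumes a running distribution on port 5000
```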
@@ -18,44 +18,42 @@ from termcolor import cprint
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.core.datatypes import RemoteProviderConfig

-from .api import *  # noqa: F403
+from .agents import *  # noqa: F403
 from .event_logger import EventLogger


 async def get_client_impl(config: RemoteProviderConfig, _deps):
-    return AgenticSystemClient(config.url)
+    return AgentsClient(config.url)


 def encodable_dict(d: BaseModel):
     return json.loads(d.json())


-class AgenticSystemClient(AgenticSystem):
+class AgentsClient(Agents):
     def __init__(self, base_url: str):
         self.base_url = base_url

-    async def create_agentic_system(
-        self, agent_config: AgentConfig
-    ) -> AgenticSystemCreateResponse:
+    async def create_agent(self, agent_config: AgentConfig) -> AgentCreateResponse:
         async with httpx.AsyncClient() as client:
             response = await client.post(
-                f"{self.base_url}/agentic_system/create",
+                f"{self.base_url}/agents/create",
                 json={
                     "agent_config": encodable_dict(agent_config),
                 },
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
-            return AgenticSystemCreateResponse(**response.json())
+            return AgentCreateResponse(**response.json())

-    async def create_agentic_system_session(
+    async def create_agent_session(
         self,
         agent_id: str,
         session_name: str,
-    ) -> AgenticSystemSessionCreateResponse:
+    ) -> AgentSessionCreateResponse:
         async with httpx.AsyncClient() as client:
             response = await client.post(
-                f"{self.base_url}/agentic_system/session/create",
+                f"{self.base_url}/agents/session/create",
                 json={
                     "agent_id": agent_id,
                     "session_name": session_name,
@@ -63,16 +61,16 @@ class AgenticSystemClient(AgenticSystem):
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
-            return AgenticSystemSessionCreateResponse(**response.json())
+            return AgentSessionCreateResponse(**response.json())

-    async def create_agentic_system_turn(
+    async def create_agent_turn(
         self,
-        request: AgenticSystemTurnCreateRequest,
+        request: AgentTurnCreateRequest,
     ) -> AsyncGenerator:
         async with httpx.AsyncClient() as client:
             async with client.stream(
                 "POST",
-                f"{self.base_url}/agentic_system/turn/create",
+                f"{self.base_url}/agents/turn/create",
                 json=encodable_dict(request),
                 headers={"Content-Type": "application/json"},
                 timeout=20,
@@ -86,7 +84,7 @@ class AgenticSystemClient(AgenticSystem):
                         cprint(data, "red")
                         continue

-                    yield AgenticSystemTurnResponseStreamChunk(**jdata)
+                    yield AgentTurnResponseStreamChunk(**jdata)
                 except Exception as e:
                     print(data)
                     print(f"Error with parsing or validation: {e}")
@@ -102,16 +100,16 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
         tool_prompt_format=ToolPromptFormat.function_tag,
     )

-    create_response = await api.create_agentic_system(agent_config)
-    session_response = await api.create_agentic_system_session(
+    create_response = await api.create_agent(agent_config)
+    session_response = await api.create_agent_session(
         agent_id=create_response.agent_id,
         session_name="test_session",
     )

     for content in user_prompts:
         cprint(f"User> {content}", color="white", attrs=["bold"])
-        iterator = api.create_agentic_system_turn(
-            AgenticSystemTurnCreateRequest(
+        iterator = api.create_agent_turn(
+            AgentTurnCreateRequest(
                 agent_id=create_response.agent_id,
                 session_id=session_response.session_id,
                 messages=[
@@ -128,7 +126,7 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):


 async def run_main(host: str, port: int):
-    api = AgenticSystemClient(f"http://{host}:{port}")
+    api = AgentsClient(f"http://{host}:{port}")

     tool_definitions = [
         SearchToolDefinition(engine=SearchEngineType.bing),
@@ -165,7 +163,7 @@ async def run_main(host: str, port: int):


 async def run_rag(host: str, port: int):
-    api = AgenticSystemClient(f"http://{host}:{port}")
+    api = AgentsClient(f"http://{host}:{port}")

     urls = [
         "memory_optimizations.rst",
llama_stack/apis/batch_inference/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .batch_inference import *  # noqa: F401 F403

@@ -11,7 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field

 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.inference.api import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403


 @json_schema_type
llama_stack/apis/dataset/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .dataset import *  # noqa: F401 F403

@@ -4,4 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .api import *  # noqa: F401 F403
+from .evals import *  # noqa: F401 F403

@@ -12,7 +12,7 @@ from llama_models.schema_utils import webmethod
 from pydantic import BaseModel

 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.dataset.api import *  # noqa: F403
+from llama_stack.apis.dataset import *  # noqa: F403
 from llama_stack.common.training_types import *  # noqa: F403

llama_stack/apis/inference/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .inference import *  # noqa: F401 F403

@@ -10,12 +10,14 @@ from typing import Any, AsyncGenerator

 import fire
 import httpx

+from llama_stack.core.datatypes import RemoteProviderConfig
 from pydantic import BaseModel
 from termcolor import cprint

-from llama_stack.core.datatypes import RemoteProviderConfig
+from .event_logger import EventLogger

-from .api import (
+from .inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseStreamChunk,
@@ -23,7 +25,6 @@ from .api import (
     Inference,
     UserMessage,
 )
-from .event_logger import EventLogger


 async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
llama_stack/apis/memory/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .memory import *  # noqa: F401 F403

@@ -13,11 +13,11 @@ from typing import Any, Dict, List, Optional

 import fire
 import httpx
-from termcolor import cprint

 from llama_stack.core.datatypes import RemoteProviderConfig
+from termcolor import cprint

-from .api import *  # noqa: F403
+from .memory import *  # noqa: F403
 from .common.file_utils import data_url_from_file

llama_stack/apis/models/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .models import *  # noqa: F401 F403
llama_stack/apis/post_training/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .post_training import *  # noqa: F401 F403

@@ -14,7 +14,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field

 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.dataset.api import *  # noqa: F403
+from llama_stack.apis.dataset import *  # noqa: F403
 from llama_stack.common.training_types import *  # noqa: F403

llama_stack/apis/reward_scoring/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .reward_scoring import *  # noqa: F401 F403

llama_stack/apis/safety/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .safety import *  # noqa: F401 F403
@@ -13,12 +13,12 @@ import fire
 import httpx

 from llama_models.llama3.api.datatypes import UserMessage

+from llama_stack.core.datatypes import RemoteProviderConfig
 from pydantic import BaseModel
 from termcolor import cprint

-from llama_stack.core.datatypes import RemoteProviderConfig
+from .safety import *  # noqa: F403

-from .api import *  # noqa: F403


 async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety:
llama_stack/apis/synthetic_data_generation/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .synthetic_data_generation import *  # noqa: F401 F403

@@ -13,7 +13,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel

 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.reward_scoring.api import *  # noqa: F403
+from llama_stack.apis.reward_scoring import *  # noqa: F403


 class FilteringFunction(Enum):
llama_stack/apis/telemetry/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .telemetry import *  # noqa: F401 F403
@@ -31,16 +31,6 @@ class LlamaCLIParser:
         ModelParser.create(subparsers)
         StackParser.create(subparsers)

-        # Import sub-commands from agentic_system if they exist
-        try:
-            from llama_agentic_system.cli.subcommand_modules import SUBCOMMAND_MODULES
-
-            for module in SUBCOMMAND_MODULES:
-                module.create(subparsers)
-
-        except ImportError:
-            pass
-
     def parse_args(self) -> argparse.Namespace:
         return self.parser.parse_args()

@@ -5,6 +5,6 @@ distribution_spec:
   inference: meta-reference
   memory: meta-reference-faiss
   safety: meta-reference
-  agentic_system: meta-reference
+  agents: meta-reference
   telemetry: console
 image_type: conda

@@ -5,6 +5,6 @@ distribution_spec:
   inference: remote::fireworks
   memory: meta-reference-faiss
   safety: meta-reference
-  agentic_system: meta-reference
+  agents: meta-reference
   telemetry: console
 image_type: conda

@@ -5,6 +5,6 @@ distribution_spec:
   inference: remote::ollama
   memory: meta-reference-faiss
   safety: meta-reference
-  agentic_system: meta-reference
+  agents: meta-reference
   telemetry: console
 image_type: conda

@@ -5,6 +5,6 @@ distribution_spec:
   inference: remote::tgi
   memory: meta-reference-faiss
   safety: meta-reference
-  agentic_system: meta-reference
+  agents: meta-reference
   telemetry: console
 image_type: conda

@@ -5,6 +5,6 @@ distribution_spec:
   inference: remote::together
   memory: meta-reference-faiss
   safety: meta-reference
-  agentic_system: meta-reference
+  agents: meta-reference
   telemetry: console
 image_type: conda

@@ -5,6 +5,6 @@ distribution_spec:
   inference: meta-reference
   memory: meta-reference-faiss
   safety: meta-reference
-  agentic_system: meta-reference
+  agents: meta-reference
   telemetry: console
 image_type: docker
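After the rename, each build spec reads the same way, only with the `agents` key. A consolidated sketch of one such config; only the keys visible in the hunks above are confirmed, and the exact nesting of the surrounding file is an assumption:

```yaml
distribution_spec:
  inference: meta-reference   # or remote::fireworks / remote::ollama / remote::tgi / remote::together
  memory: meta-reference-faiss
  safety: meta-reference
  agents: meta-reference      # was: agentic_system
  telemetry: console
image_type: conda             # the docker variant sets image_type: docker
```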
@@ -17,7 +17,7 @@ from pydantic import BaseModel, Field, validator
 class Api(Enum):
     inference = "inference"
     safety = "safety"
-    agentic_system = "agentic_system"
+    agents = "agents"
     memory = "memory"
     telemetry = "telemetry"

@@ -8,11 +8,11 @@ import importlib
 import inspect
 from typing import Dict, List

-from llama_stack.agentic_system.api import AgenticSystem
-from llama_stack.inference.api import Inference
-from llama_stack.memory.api import Memory
-from llama_stack.safety.api import Safety
-from llama_stack.telemetry.api import Telemetry
+from llama_stack.apis.agents import Agents
+from llama_stack.apis.inference import Inference
+from llama_stack.apis.memory import Memory
+from llama_stack.apis.safety import Safety
+from llama_stack.apis.telemetry import Telemetry

 from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec

@@ -34,7 +34,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
     protocols = {
         Api.inference: Inference,
         Api.safety: Safety,
-        Api.agentic_system: AgenticSystem,
+        Api.agents: Agents,
         Api.memory: Memory,
         Api.telemetry: Telemetry,
     }
@@ -67,7 +67,7 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
     ret = {}
     for api in stack_apis():
         name = api.name.lower()
-        module = importlib.import_module(f"llama_stack.{name}.providers")
+        module = importlib.import_module(f"llama_stack.providers.registry.{name}")
         ret[api] = {
             "remote": remote_provider_spec(api),
             **{a.provider_id: a for a in module.available_providers()},
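The registry move means provider discovery is one importlib call per API: `llama_stack.providers.registry.<api>` must expose `available_providers()`. A small sketch of that contract, with a hypothetical helper name:

```python
import importlib
from typing import Any, Dict


def providers_for(api_name: str) -> Dict[str, Any]:
    # Mirrors the loop in api_providers() above: import the per-API registry
    # module and index its provider specs by provider_id.
    module = importlib.import_module(f"llama_stack.providers.registry.{api_name}")
    return {spec.provider_id: spec for spec in module.available_providers()}


# e.g. providers_for("inference") would cover ids like "meta-reference"
# and the remote adapters referenced by the build configs above.
```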
@@ -39,7 +39,7 @@ from pydantic import BaseModel, ValidationError
 from termcolor import cprint
 from typing_extensions import Annotated

-from llama_stack.telemetry.tracing import (
+from llama_stack.providers.utils.telemetry.tracing import (
     end_trace,
     setup_logger,
     SpanStatus,

@@ -1,7 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .api import *  # noqa: F401 F403

@@ -1,7 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .api import *  # noqa: F401 F403

@@ -1,7 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .api import *  # noqa: F401 F403
@@ -6,15 +6,16 @@

 from typing import AsyncGenerator

-from fireworks.client import Fireworks
 from llama_models.llama3.api.chat_format import ChatFormat

 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model

-from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.inference.prepare_messages import prepare_messages
+from fireworks.client import Fireworks
+
+from llama_stack.apis.inference import *  # noqa: F403
+
+from llama_stack.providers.utils.inference.prepare_messages import prepare_messages

 from .config import FireworksImplConfig

@@ -81,7 +82,7 @@ class FireworksInferenceAdapter(Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        tools: Optional[List[ToolDefinition]] = list(),
+        tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
@@ -91,7 +92,7 @@ class FireworksInferenceAdapter(Inference):
            model=model,
            messages=messages,
            sampling_params=sampling_params,
-           tools=tools,
+           tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
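The `tools: Optional[List[ToolDefinition]] = list()` to `= None` change in this and the following adapters fixes Python's shared-mutable-default pitfall: a default value is evaluated once at function-definition time, so one list object is silently shared across every call. A standalone illustration of the failure mode and the `None` + `or []` idiom the diff adopts:

```python
def buggy(items=[]):  # the [] is created once, at def time
    items.append(1)
    return items


def fixed(items=None):  # sentinel default; fresh list per call
    items = items or []
    items.append(1)
    return items


assert buggy() == [1]
assert buggy() == [1, 1]  # the shared default has accumulated state
assert fixed() == [1]
assert fixed() == [1]  # independent calls stay independent
```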
@@ -12,10 +12,11 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model

 from ollama import AsyncClient

-from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.inference.prepare_messages import prepare_messages
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.providers.utils.inference.prepare_messages import prepare_messages

 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models

@@ -89,7 +90,7 @@ class OllamaInferenceAdapter(Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        tools: Optional[List[ToolDefinition]] = list(),
+        tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,

@@ -99,7 +100,7 @@ class OllamaInferenceAdapter(Inference):
            model=model,
            messages=messages,
            sampling_params=sampling_params,
-           tools=tools,
+           tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,

@@ -13,8 +13,8 @@ from huggingface_hub import HfApi, InferenceClient
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.inference.prepare_messages import prepare_messages
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.providers.utils.inference.prepare_messages import prepare_messages

 from .config import TGIImplConfig

@@ -87,7 +87,7 @@ class TGIAdapter(Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        tools: Optional[List[ToolDefinition]] = list(),
+        tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,

@@ -97,7 +97,7 @@ class TGIAdapter(Inference):
            model=model,
            messages=messages,
            sampling_params=sampling_params,
-           tools=tools,
+           tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,

@@ -11,10 +11,11 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model

 from together import Together

-from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.inference.prepare_messages import prepare_messages
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.providers.utils.inference.prepare_messages import prepare_messages

 from .config import TogetherImplConfig

@@ -81,7 +82,7 @@ class TogetherInferenceAdapter(Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        tools: Optional[List[ToolDefinition]] = list(),
+        tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,

@@ -92,7 +93,7 @@ class TogetherInferenceAdapter(Inference):
            model=model,
            messages=messages,
            sampling_params=sampling_params,
-           tools=tools,
+           tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
@@ -12,10 +12,13 @@ from urllib.parse import urlparse
 import chromadb
 from numpy.typing import NDArray

-from llama_stack.memory.api import *  # noqa: F403
+from llama_stack.apis.memory import *  # noqa: F403


-from llama_stack.memory.common.vector_store import BankWithIndex, EmbeddingIndex
+from llama_stack.providers.utils.memory.vector_store import (
+    BankWithIndex,
+    EmbeddingIndex,
+)


 class ChromaIndex(EmbeddingIndex):

@@ -13,10 +13,10 @@ from numpy.typing import NDArray
 from psycopg2 import sql
 from psycopg2.extras import execute_values, Json
 from pydantic import BaseModel
-from llama_stack.memory.api import *  # noqa: F403
+from llama_stack.apis.memory import *  # noqa: F403


-from llama_stack.memory.common.vector_store import (
+from llama_stack.providers.utils.memory.vector_store import (
     ALL_MINILM_L6_V2_DIMENSION,
     BankWithIndex,
     EmbeddingIndex,
@@ -14,13 +14,13 @@ from .config import MetaReferenceImplConfig
 async def get_provider_impl(
     config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
 ):
-    from .agentic_system import MetaReferenceAgenticSystemImpl
+    from .agents import MetaReferenceAgentsImpl

     assert isinstance(
         config, MetaReferenceImplConfig
     ), f"Unexpected config type: {type(config)}"

-    impl = MetaReferenceAgenticSystemImpl(
+    impl = MetaReferenceAgentsImpl(
         config,
         deps[Api.inference],
         deps[Api.memory],

@@ -20,10 +20,10 @@ import httpx

 from termcolor import cprint

-from llama_stack.agentic_system.api import *  # noqa: F403
-from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.memory.api import *  # noqa: F403
-from llama_stack.safety.api import *  # noqa: F403
+from llama_stack.apis.agents import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.apis.safety import *  # noqa: F403

 from llama_stack.tools.base import BaseTool
 from llama_stack.tools.builtin import (
@@ -122,7 +122,7 @@ class ChatAgent(ShieldRunnerMixin):
         return session

     async def create_and_execute_turn(
-        self, request: AgenticSystemTurnCreateRequest
+        self, request: AgentTurnCreateRequest
     ) -> AsyncGenerator:
         assert (
             request.session_id in self.sessions

@@ -141,9 +141,9 @@ class ChatAgent(ShieldRunnerMixin):

         turn_id = str(uuid.uuid4())
         start_time = datetime.now()
-        yield AgenticSystemTurnResponseStreamChunk(
-            event=AgenticSystemTurnResponseEvent(
-                payload=AgenticSystemTurnResponseTurnStartPayload(
+        yield AgentTurnResponseStreamChunk(
+            event=AgentTurnResponseEvent(
+                payload=AgentTurnResponseTurnStartPayload(
                     turn_id=turn_id,
                 )
             )

@@ -169,12 +169,12 @@ class ChatAgent(ShieldRunnerMixin):
                 continue

             assert isinstance(
-                chunk, AgenticSystemTurnResponseStreamChunk
+                chunk, AgentTurnResponseStreamChunk
             ), f"Unexpected type {type(chunk)}"
             event = chunk.event
             if (
                 event.payload.event_type
-                == AgenticSystemTurnResponseEventType.step_complete.value
+                == AgentTurnResponseEventType.step_complete.value
             ):
                 steps.append(event.payload.step_details)

@@ -193,9 +193,9 @@ class ChatAgent(ShieldRunnerMixin):
         )
         session.turns.append(turn)

-        chunk = AgenticSystemTurnResponseStreamChunk(
-            event=AgenticSystemTurnResponseEvent(
-                payload=AgenticSystemTurnResponseTurnCompletePayload(
+        chunk = AgentTurnResponseStreamChunk(
+            event=AgentTurnResponseEvent(
+                payload=AgentTurnResponseTurnCompletePayload(
                     turn=turn,
                 )
             )

@@ -261,9 +261,9 @@ class ChatAgent(ShieldRunnerMixin):

         step_id = str(uuid.uuid4())
         try:
-            yield AgenticSystemTurnResponseStreamChunk(
-                event=AgenticSystemTurnResponseEvent(
-                    payload=AgenticSystemTurnResponseStepStartPayload(
+            yield AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseStepStartPayload(
                         step_type=StepType.shield_call.value,
                         step_id=step_id,
                         metadata=dict(touchpoint=touchpoint),

@@ -273,9 +273,9 @@ class ChatAgent(ShieldRunnerMixin):
             await self.run_shields(messages, shields)

         except SafetyException as e:
-            yield AgenticSystemTurnResponseStreamChunk(
-                event=AgenticSystemTurnResponseEvent(
-                    payload=AgenticSystemTurnResponseStepCompletePayload(
+            yield AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseStepCompletePayload(
                         step_type=StepType.shield_call.value,
                         step_details=ShieldCallStep(
                             step_id=step_id,

@@ -292,9 +292,9 @@ class ChatAgent(ShieldRunnerMixin):
         )
         yield False

-        yield AgenticSystemTurnResponseStreamChunk(
-            event=AgenticSystemTurnResponseEvent(
-                payload=AgenticSystemTurnResponseStepCompletePayload(
+        yield AgentTurnResponseStreamChunk(
+            event=AgentTurnResponseEvent(
+                payload=AgentTurnResponseStepCompletePayload(
                     step_type=StepType.shield_call.value,
                     step_details=ShieldCallStep(
                         step_id=step_id,

@@ -325,9 +325,9 @@ class ChatAgent(ShieldRunnerMixin):
         )
         if need_rag_context:
             step_id = str(uuid.uuid4())
-            yield AgenticSystemTurnResponseStreamChunk(
-                event=AgenticSystemTurnResponseEvent(
-                    payload=AgenticSystemTurnResponseStepStartPayload(
+            yield AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseStepStartPayload(
|
||||||
step_type=StepType.memory_retrieval.value,
|
step_type=StepType.memory_retrieval.value,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
)
|
)
|
||||||
|
@ -341,9 +341,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
step_id = str(uuid.uuid4())
|
step_id = str(uuid.uuid4())
|
||||||
yield AgenticSystemTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgenticSystemTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgenticSystemTurnResponseStepCompletePayload(
|
payload=AgentTurnResponseStepCompletePayload(
|
||||||
step_type=StepType.memory_retrieval.value,
|
step_type=StepType.memory_retrieval.value,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
step_details=MemoryRetrievalStep(
|
step_details=MemoryRetrievalStep(
|
||||||
|
@ -360,7 +360,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
last_message = input_messages[-1]
|
last_message = input_messages[-1]
|
||||||
last_message.context = "\n".join(rag_context)
|
last_message.context = "\n".join(rag_context)
|
||||||
|
|
||||||
elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools:
|
elif attachments and AgentTool.code_interpreter.value in enabled_tools:
|
||||||
urls = [a.content for a in attachments if isinstance(a.content, URL)]
|
urls = [a.content for a in attachments if isinstance(a.content, URL)]
|
||||||
msg = await attachment_message(self.tempdir, urls)
|
msg = await attachment_message(self.tempdir, urls)
|
||||||
input_messages.append(msg)
|
input_messages.append(msg)
|
||||||
|
@ -379,9 +379,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
cprint(f"{str(msg)}", color=color)
|
cprint(f"{str(msg)}", color=color)
|
||||||
|
|
||||||
step_id = str(uuid.uuid4())
|
step_id = str(uuid.uuid4())
|
||||||
yield AgenticSystemTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgenticSystemTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgenticSystemTurnResponseStepStartPayload(
|
payload=AgentTurnResponseStepStartPayload(
|
||||||
step_type=StepType.inference.value,
|
step_type=StepType.inference.value,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
)
|
)
|
||||||
|
@ -412,9 +412,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
tool_calls.append(delta.content)
|
tool_calls.append(delta.content)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
yield AgenticSystemTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgenticSystemTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgenticSystemTurnResponseStepProgressPayload(
|
payload=AgentTurnResponseStepProgressPayload(
|
||||||
step_type=StepType.inference.value,
|
step_type=StepType.inference.value,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
model_response_text_delta="",
|
model_response_text_delta="",
|
||||||
|
@ -426,9 +426,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
elif isinstance(delta, str):
|
elif isinstance(delta, str):
|
||||||
content += delta
|
content += delta
|
||||||
if stream and event.stop_reason is None:
|
if stream and event.stop_reason is None:
|
||||||
yield AgenticSystemTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgenticSystemTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgenticSystemTurnResponseStepProgressPayload(
|
payload=AgentTurnResponseStepProgressPayload(
|
||||||
step_type=StepType.inference.value,
|
step_type=StepType.inference.value,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
model_response_text_delta=event.delta,
|
model_response_text_delta=event.delta,
|
||||||
|
@ -448,9 +448,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield AgenticSystemTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgenticSystemTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgenticSystemTurnResponseStepCompletePayload(
|
payload=AgentTurnResponseStepCompletePayload(
|
||||||
step_type=StepType.inference.value,
|
step_type=StepType.inference.value,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
step_details=InferenceStep(
|
step_details=InferenceStep(
|
||||||
|
@ -498,17 +498,17 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
return
|
return
|
||||||
|
|
||||||
step_id = str(uuid.uuid4())
|
step_id = str(uuid.uuid4())
|
||||||
yield AgenticSystemTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgenticSystemTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgenticSystemTurnResponseStepStartPayload(
|
payload=AgentTurnResponseStepStartPayload(
|
||||||
step_type=StepType.tool_execution.value,
|
step_type=StepType.tool_execution.value,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
yield AgenticSystemTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgenticSystemTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgenticSystemTurnResponseStepProgressPayload(
|
payload=AgentTurnResponseStepProgressPayload(
|
||||||
step_type=StepType.tool_execution.value,
|
step_type=StepType.tool_execution.value,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
tool_call=tool_call,
|
tool_call=tool_call,
|
||||||
|
@ -525,9 +525,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
), "Currently not supporting multiple messages"
|
), "Currently not supporting multiple messages"
|
||||||
result_message = result_messages[0]
|
result_message = result_messages[0]
|
||||||
|
|
||||||
yield AgenticSystemTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgenticSystemTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgenticSystemTurnResponseStepCompletePayload(
|
payload=AgentTurnResponseStepCompletePayload(
|
||||||
step_type=StepType.tool_execution.value,
|
step_type=StepType.tool_execution.value,
|
||||||
step_details=ToolExecutionStep(
|
step_details=ToolExecutionStep(
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
|
@ -547,9 +547,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
# TODO: add tool-input touchpoint and a "start" event for this step also
|
# TODO: add tool-input touchpoint and a "start" event for this step also
|
||||||
# but that needs a lot more refactoring of Tool code potentially
|
# but that needs a lot more refactoring of Tool code potentially
|
||||||
yield AgenticSystemTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgenticSystemTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgenticSystemTurnResponseStepCompletePayload(
|
payload=AgentTurnResponseStepCompletePayload(
|
||||||
step_type=StepType.shield_call.value,
|
step_type=StepType.shield_call.value,
|
||||||
step_details=ShieldCallStep(
|
step_details=ShieldCallStep(
|
||||||
step_id=str(uuid.uuid4()),
|
step_id=str(uuid.uuid4()),
|
||||||
|
@ -566,9 +566,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
except SafetyException as e:
|
except SafetyException as e:
|
||||||
yield AgenticSystemTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgenticSystemTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgenticSystemTurnResponseStepCompletePayload(
|
payload=AgentTurnResponseStepCompletePayload(
|
||||||
step_type=StepType.shield_call.value,
|
step_type=StepType.shield_call.value,
|
||||||
step_details=ShieldCallStep(
|
step_details=ShieldCallStep(
|
||||||
step_id=str(uuid.uuid4()),
|
step_id=str(uuid.uuid4()),
|
||||||
|
@ -616,18 +616,18 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
enabled_tools = set(t.type for t in self.agent_config.tools)
|
enabled_tools = set(t.type for t in self.agent_config.tools)
|
||||||
if attachments:
|
if attachments:
|
||||||
if (
|
if (
|
||||||
AgenticSystemTool.code_interpreter.value in enabled_tools
|
AgentTool.code_interpreter.value in enabled_tools
|
||||||
and self.agent_config.tool_choice == ToolChoice.required
|
and self.agent_config.tool_choice == ToolChoice.required
|
||||||
):
|
):
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return AgenticSystemTool.memory.value in enabled_tools
|
return AgentTool.memory.value in enabled_tools
|
||||||
|
|
||||||
def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
|
def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
|
||||||
for t in self.agent_config.tools:
|
for t in self.agent_config.tools:
|
||||||
if t.type == AgenticSystemTool.memory.value:
|
if t.type == AgentTool.memory.value:
|
||||||
return t
|
return t
|
||||||
|
|
||||||
return None
|
return None
|
|
@ -10,10 +10,10 @@ import tempfile
|
||||||
import uuid
|
import uuid
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
from llama_stack.inference.api import Inference
|
from llama_stack.apis.inference import Inference
|
||||||
from llama_stack.memory.api import Memory
|
from llama_stack.apis.memory import Memory
|
||||||
from llama_stack.safety.api import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.agentic_system.api import * # noqa: F403
|
from llama_stack.apis.agents import * # noqa: F403
|
||||||
from llama_stack.tools.builtin import (
|
from llama_stack.tools.builtin import (
|
||||||
CodeInterpreterTool,
|
CodeInterpreterTool,
|
||||||
PhotogenTool,
|
PhotogenTool,
|
||||||
|
@ -33,7 +33,7 @@ logger.setLevel(logging.INFO)
|
||||||
AGENT_INSTANCES_BY_ID = {}
|
AGENT_INSTANCES_BY_ID = {}
|
||||||
|
|
||||||
|
|
||||||
class MetaReferenceAgenticSystemImpl(AgenticSystem):
|
class MetaReferenceAgentsImpl(Agents):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: MetaReferenceImplConfig,
|
config: MetaReferenceImplConfig,
|
||||||
|
@ -49,10 +49,10 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def create_agentic_system(
|
async def create_agent(
|
||||||
self,
|
self,
|
||||||
agent_config: AgentConfig,
|
agent_config: AgentConfig,
|
||||||
) -> AgenticSystemCreateResponse:
|
) -> AgentCreateResponse:
|
||||||
agent_id = str(uuid.uuid4())
|
agent_id = str(uuid.uuid4())
|
||||||
|
|
||||||
builtin_tools = []
|
builtin_tools = []
|
||||||
|
@ -95,24 +95,24 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
|
||||||
builtin_tools=builtin_tools,
|
builtin_tools=builtin_tools,
|
||||||
)
|
)
|
||||||
|
|
||||||
return AgenticSystemCreateResponse(
|
return AgentCreateResponse(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def create_agentic_system_session(
|
async def create_agent_session(
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
session_name: str,
|
session_name: str,
|
||||||
) -> AgenticSystemSessionCreateResponse:
|
) -> AgentSessionCreateResponse:
|
||||||
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
|
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
|
||||||
agent = AGENT_INSTANCES_BY_ID[agent_id]
|
agent = AGENT_INSTANCES_BY_ID[agent_id]
|
||||||
|
|
||||||
session = agent.create_session(session_name)
|
session = agent.create_session(session_name)
|
||||||
return AgenticSystemSessionCreateResponse(
|
return AgentSessionCreateResponse(
|
||||||
session_id=session.session_id,
|
session_id=session.session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def create_agentic_system_turn(
|
async def create_agent_turn(
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
|
@ -126,7 +126,7 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||||
request = AgenticSystemTurnCreateRequest(
|
request = AgentTurnCreateRequest(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
messages=messages,
|
messages=messages,
|
|
@@ -10,14 +10,14 @@ from jinja2 import Template
 from llama_models.llama3.api import *  # noqa: F403


-from llama_stack.agentic_system.api import (
+from llama_stack.apis.agents import (
     DefaultMemoryQueryGeneratorConfig,
     LLMMemoryQueryGeneratorConfig,
     MemoryQueryGenerator,
     MemoryQueryGeneratorConfig,
 )
 from termcolor import cprint  # noqa: F401
-from llama_stack.inference.api import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403


 async def generate_rag_query(
@@ -7,15 +7,15 @@
 from typing import List

 from llama_models.llama3.api.datatypes import Message, Role, UserMessage
-from termcolor import cprint

-from llama_stack.safety.api import (
+from llama_stack.apis.safety import (
     OnViolationAction,
     RunShieldRequest,
     Safety,
     ShieldDefinition,
     ShieldResponse,
 )
+from termcolor import cprint


 class SafetyException(Exception):  # noqa: N818
@@ -11,9 +11,9 @@ from llama_models.datatypes import ModelFamily
 from llama_models.schema_utils import json_schema_type
 from llama_models.sku_list import all_registered_models, resolve_model

-from pydantic import BaseModel, Field, field_validator
+from llama_stack.apis.inference import QuantizationConfig

-from llama_stack.inference.api import QuantizationConfig
+from pydantic import BaseModel, Field, field_validator


 @json_schema_type
@@ -28,10 +28,10 @@ from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.sku_list import resolve_model
-from termcolor import cprint
+from llama_stack.apis.inference import QuantizationType

 from llama_stack.common.model_utils import model_local_dir
-from llama_stack.inference.api import QuantizationType
+from termcolor import cprint

 from .config import MetaReferenceImplConfig

@@ -11,7 +11,7 @@ from typing import AsyncIterator, Union
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.sku_list import resolve_model

-from llama_stack.inference.api import (
+from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseEvent,
@@ -21,13 +21,13 @@ from llama_stack.inference.api import (
     ToolCallDelta,
     ToolCallParseStatus,
 )
-from llama_stack.inference.prepare_messages import prepare_messages
+from llama_stack.providers.utils.inference.prepare_messages import prepare_messages

 from .config import MetaReferenceImplConfig
 from .model_parallel import LlamaModelParallelGenerator

 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.inference.api import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403

 # there's a single model parallel process running serving the model. for now,
 # we don't support multiple concurrent requests to this process.
@@ -57,7 +57,7 @@ class MetaReferenceInferenceImpl(Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        tools: Optional[List[ToolDefinition]] = list(),
+        tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
@@ -70,7 +70,7 @@ class MetaReferenceInferenceImpl(Inference):
             model=model,
             messages=messages,
             sampling_params=sampling_params,
-            tools=tools,
+            tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
             stream=stream,
@@ -14,9 +14,9 @@ import torch

 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
 from llama_models.llama3.api.model import Transformer, TransformerBlock
-from llama_stack.inference.api import QuantizationType
+from llama_stack.apis.inference import QuantizationType

-from llama_stack.inference.api.config import (
+from llama_stack.apis.inference.config import (
     CheckpointQuantizationFormat,
     MetaReferenceImplConfig,
 )
@@ -15,13 +15,14 @@ from numpy.typing import NDArray

 from llama_models.llama3.api.datatypes import *  # noqa: F403

-from llama_stack.memory.api import *  # noqa: F403
+from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.memory.common.vector_store import (
+from llama_stack.providers.utils.memory.vector_store import (
     ALL_MINILM_L6_V2_DIMENSION,
     BankWithIndex,
     EmbeddingIndex,
 )
-from llama_stack.telemetry import tracing
+from llama_stack.providers.utils.telemetry import tracing

 from .config import FaissImplConfig

 logger = logging.getLogger(__name__)
@@ -9,7 +9,7 @@ import asyncio
 from llama_models.sku_list import resolve_model

 from llama_stack.common.model_utils import model_local_dir
-from llama_stack.safety.api import *  # noqa
+from llama_stack.apis.safety import *  # noqa

 from .config import SafetyConfig
 from .shields import (
@@ -8,7 +8,7 @@ from abc import ABC, abstractmethod
 from typing import List

 from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
-from llama_stack.safety.api import *  # noqa: F403
+from llama_stack.apis.safety import *  # noqa: F403

 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"

@@ -8,7 +8,7 @@ from codeshield.cs import CodeShield
 from termcolor import cprint

 from .base import ShieldResponse, TextShield
-from llama_stack.safety.api import *  # noqa: F403
+from llama_stack.apis.safety import *  # noqa: F403


 class CodeScannerShield(TextShield):
@@ -14,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, Role
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
-from llama_stack.safety.api import *  # noqa: F403
+from llama_stack.apis.safety import *  # noqa: F403

 SAFE_RESPONSE = "safe"
 _INSTANCE = None
@@ -14,7 +14,7 @@ from termcolor import cprint
 from transformers import AutoModelForSequenceClassification, AutoTokenizer

 from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
-from llama_stack.safety.api import *  # noqa: F403
+from llama_stack.apis.safety import *  # noqa: F403


 class PromptGuardShield(TextShield):
@@ -6,7 +6,7 @@

 from typing import Optional

-from llama_stack.telemetry.api import *  # noqa: F403
+from llama_stack.apis.telemetry import *  # noqa: F403
 from .config import ConsoleConfig


@@ -12,7 +12,7 @@ from llama_stack.core.datatypes import Api, InlineProviderSpec, ProviderSpec
 def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
-            api=Api.agentic_system,
+            api=Api.agents,
             provider_id="meta-reference",
             pip_packages=[
                 "codeshield",
@@ -23,8 +23,8 @@ def available_providers() -> List[ProviderSpec]:
                 "torch",
                 "transformers",
             ],
-            module="llama_stack.agentic_system.meta_reference",
+            module="llama_stack.providers.impls.meta_reference.agents",
-            config_class="llama_stack.agentic_system.meta_reference.MetaReferenceImplConfig",
+            config_class="llama_stack.providers.impls.meta_reference.agents.MetaReferenceImplConfig",
             api_dependencies=[
                 Api.inference,
                 Api.safety,
Some files were not shown because too many files have changed in this diff.