Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)
Codemod from llama_toolchain -> llama_stack

- added providers/registry
- cleaned up api/ subdirectories and moved impls away
- restructured api/api.py
- `from llama_stack.apis.<api> import foo` should work now
- updated imports to use llama_stack.apis.<api>
- updated many other imports
- added __init__, fixed some registry imports
- updated registry imports
- create_agentic_system -> create_agent
- AgenticSystem -> Agent
commit 76b354a081 (parent 2cf731faea)

128 changed files with 381 additions and 376 deletions
@@ -482,7 +482,7 @@ Once the server is setup, we can test it with a client to see the example output
 cd /path/to/llama-stack
 conda activate <env>  # any environment containing the llama-toolchain pip package will work

-python -m llama_stack.inference.client localhost 5000
+python -m llama_stack.apis.inference.client localhost 5000
 ```

 This will run the chat completion client and query the distribution’s /inference/chat_completion API.

@@ -296,7 +296,7 @@ Once the server is setup, we can test it with a client to see the example output
 cd /path/to/llama-stack
 conda activate <env>  # any environment containing the llama-toolchain pip package will work

-python -m llama_stack.inference.client localhost 5000
+python -m llama_stack.apis.inference.client localhost 5000
 ```

 This will run the chat completion client and query the distribution’s /inference/chat_completion API.

@@ -314,7 +314,7 @@
 Similarly you can test safety (if you configured llama-guard and/or prompt-guard shields) by:

 ```
-python -m llama_stack.safety.client localhost 5000
+python -m llama_stack.apis.safety.client localhost 5000
 ```

 You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/sdk_examples) repo.
llama_stack/apis/agents/__init__.py (new file, +7)

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .agents import *  # noqa: F401 F403
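These seven-line `__init__.py` files are what make the new import style from the commit message work: each `llama_stack/apis/<api>` package star-re-exports its datatypes module. A quick illustration (assuming `AgentConfig` is among the names defined in `agents.py`, as the hunks below suggest):

```python
# Old layout (pre-codemod):
#   from llama_stack.agentic_system.api import AgentConfig
#
# New layout: the package __init__ re-exports everything, so both resolve:
from llama_stack.apis.agents import AgentConfig          # via `from .agents import *`
from llama_stack.apis.agents.agents import AgentConfig   # direct module path
```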
@@ -15,9 +15,9 @@ from typing_extensions import Annotated

 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.common.deployment_types import *  # noqa: F403
-from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.safety.api import *  # noqa: F403
-from llama_stack.memory.api import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.safety import *  # noqa: F403
+from llama_stack.apis.memory import *  # noqa: F403


 @json_schema_type

@@ -26,7 +26,7 @@ class Attachment(BaseModel):
     mime_type: str


-class AgenticSystemTool(Enum):
+class AgentTool(Enum):
     brave_search = "brave_search"
     wolfram_alpha = "wolfram_alpha"
     photogen = "photogen"

@@ -50,41 +50,33 @@ class SearchEngineType(Enum):
 class SearchToolDefinition(ToolDefinitionCommon):
     # NOTE: brave_search is just a placeholder since model always uses
     # brave_search as tool call name
-    type: Literal[AgenticSystemTool.brave_search.value] = (
-        AgenticSystemTool.brave_search.value
-    )
+    type: Literal[AgentTool.brave_search.value] = AgentTool.brave_search.value
     engine: SearchEngineType = SearchEngineType.brave
     remote_execution: Optional[RestAPIExecutionConfig] = None


 @json_schema_type
 class WolframAlphaToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgenticSystemTool.wolfram_alpha.value] = (
-        AgenticSystemTool.wolfram_alpha.value
-    )
+    type: Literal[AgentTool.wolfram_alpha.value] = AgentTool.wolfram_alpha.value
     remote_execution: Optional[RestAPIExecutionConfig] = None


 @json_schema_type
 class PhotogenToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgenticSystemTool.photogen.value] = AgenticSystemTool.photogen.value
+    type: Literal[AgentTool.photogen.value] = AgentTool.photogen.value
     remote_execution: Optional[RestAPIExecutionConfig] = None


 @json_schema_type
 class CodeInterpreterToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgenticSystemTool.code_interpreter.value] = (
-        AgenticSystemTool.code_interpreter.value
-    )
+    type: Literal[AgentTool.code_interpreter.value] = AgentTool.code_interpreter.value
     enable_inline_code_execution: bool = True
     remote_execution: Optional[RestAPIExecutionConfig] = None


 @json_schema_type
 class FunctionCallToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgenticSystemTool.function_call.value] = (
-        AgenticSystemTool.function_call.value
-    )
+    type: Literal[AgentTool.function_call.value] = AgentTool.function_call.value
     function_name: str
     description: str
     parameters: Dict[str, ToolParamDefinition]

@@ -95,30 +87,30 @@ class _MemoryBankConfigCommon(BaseModel):
     bank_id: str


-class AgenticSystemVectorMemoryBankConfig(_MemoryBankConfigCommon):
+class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon):
     type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value


-class AgenticSystemKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
+class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
     type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
     keys: List[str]  # what keys to focus on


-class AgenticSystemKeywordMemoryBankConfig(_MemoryBankConfigCommon):
+class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon):
     type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value


-class AgenticSystemGraphMemoryBankConfig(_MemoryBankConfigCommon):
+class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon):
     type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
     entities: List[str]  # what entities to focus on


 MemoryBankConfig = Annotated[
     Union[
-        AgenticSystemVectorMemoryBankConfig,
-        AgenticSystemKeyValueMemoryBankConfig,
-        AgenticSystemKeywordMemoryBankConfig,
-        AgenticSystemGraphMemoryBankConfig,
+        AgentVectorMemoryBankConfig,
+        AgentKeyValueMemoryBankConfig,
+        AgentKeywordMemoryBankConfig,
+        AgentGraphMemoryBankConfig,
     ],
     Field(discriminator="type"),
 ]
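The `Annotated[Union[...], Field(discriminator=...)]` pattern used for `MemoryBankConfig` (and again below for tool definitions and event payloads) is Pydantic's tagged-union mechanism: the `type` Literal on each subclass doubles as the wire-format tag. A self-contained sketch of the same idea (toy class names, not part of this diff):

```python
from typing import Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class VectorConfig(BaseModel):
    type: Literal["vector"] = "vector"
    bank_id: str


class KeyValueConfig(BaseModel):
    type: Literal["keyvalue"] = "keyvalue"
    bank_id: str
    keys: list[str]


# Pydantic dispatches on the "type" tag, as MemoryBankConfig does above.
BankConfig = Annotated[Union[VectorConfig, KeyValueConfig], Field(discriminator="type")]


class Holder(BaseModel):
    config: BankConfig


h = Holder.model_validate(
    {"config": {"type": "keyvalue", "bank_id": "b1", "keys": ["k"]}}
)
assert isinstance(h.config, KeyValueConfig)  # correct subclass chosen from the tag
```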
@@ -158,7 +150,7 @@ MemoryQueryGeneratorConfig = Annotated[


 class MemoryToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
+    type: Literal[AgentTool.memory.value] = AgentTool.memory.value
     memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
     # This config defines how a query is generated using the messages
     # for memory bank retrieval.

@@ -169,7 +161,7 @@ class MemoryToolDefinition(ToolDefinitionCommon):
     max_chunks: int = 10


-AgenticSystemToolDefinition = Annotated[
+AgentToolDefinition = Annotated[
     Union[
         SearchToolDefinition,
         WolframAlphaToolDefinition,

@@ -275,7 +267,7 @@ class AgentConfigCommon(BaseModel):
     input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
     output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)

-    tools: Optional[List[AgenticSystemToolDefinition]] = Field(default_factory=list)
+    tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
     tool_prompt_format: Optional[ToolPromptFormat] = Field(
         default=ToolPromptFormat.json

@@ -292,7 +284,7 @@ class AgentConfigOverridablePerTurn(AgentConfigCommon):
     instructions: Optional[str] = None


-class AgenticSystemTurnResponseEventType(Enum):
+class AgentTurnResponseEventType(Enum):
     step_start = "step_start"
     step_complete = "step_complete"
     step_progress = "step_progress"

@@ -302,9 +294,9 @@ class AgenticSystemTurnResponseEventType(Enum):


 @json_schema_type
-class AgenticSystemTurnResponseStepStartPayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
-        AgenticSystemTurnResponseEventType.step_start.value
+class AgentTurnResponseStepStartPayload(BaseModel):
+    event_type: Literal[AgentTurnResponseEventType.step_start.value] = (
+        AgentTurnResponseEventType.step_start.value
     )
     step_type: StepType
     step_id: str

@@ -312,20 +304,20 @@ class AgenticSystemTurnResponseStepStartPayload(BaseModel):


 @json_schema_type
-class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
-        AgenticSystemTurnResponseEventType.step_complete.value
+class AgentTurnResponseStepCompletePayload(BaseModel):
+    event_type: Literal[AgentTurnResponseEventType.step_complete.value] = (
+        AgentTurnResponseEventType.step_complete.value
     )
     step_type: StepType
     step_details: Step


 @json_schema_type
-class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
+class AgentTurnResponseStepProgressPayload(BaseModel):
     model_config = ConfigDict(protected_namespaces=())

-    event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
-        AgenticSystemTurnResponseEventType.step_progress.value
+    event_type: Literal[AgentTurnResponseEventType.step_progress.value] = (
+        AgentTurnResponseEventType.step_progress.value
     )
     step_type: StepType
     step_id: str

@@ -336,49 +328,49 @@ class AgenticSystemTurnResponseStepProgressPayload(BaseModel):


 @json_schema_type
-class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
-        AgenticSystemTurnResponseEventType.turn_start.value
+class AgentTurnResponseTurnStartPayload(BaseModel):
+    event_type: Literal[AgentTurnResponseEventType.turn_start.value] = (
+        AgentTurnResponseEventType.turn_start.value
     )
     turn_id: str


 @json_schema_type
-class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
-        AgenticSystemTurnResponseEventType.turn_complete.value
+class AgentTurnResponseTurnCompletePayload(BaseModel):
+    event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = (
+        AgentTurnResponseEventType.turn_complete.value
     )
     turn: Turn


 @json_schema_type
-class AgenticSystemTurnResponseEvent(BaseModel):
+class AgentTurnResponseEvent(BaseModel):
     """Streamed agent execution response."""

     payload: Annotated[
         Union[
-            AgenticSystemTurnResponseStepStartPayload,
-            AgenticSystemTurnResponseStepProgressPayload,
-            AgenticSystemTurnResponseStepCompletePayload,
-            AgenticSystemTurnResponseTurnStartPayload,
-            AgenticSystemTurnResponseTurnCompletePayload,
+            AgentTurnResponseStepStartPayload,
+            AgentTurnResponseStepProgressPayload,
+            AgentTurnResponseStepCompletePayload,
+            AgentTurnResponseTurnStartPayload,
+            AgentTurnResponseTurnCompletePayload,
         ],
         Field(discriminator="event_type"),
     ]


 @json_schema_type
-class AgenticSystemCreateResponse(BaseModel):
+class AgentCreateResponse(BaseModel):
     agent_id: str


 @json_schema_type
-class AgenticSystemSessionCreateResponse(BaseModel):
+class AgentSessionCreateResponse(BaseModel):
     session_id: str


 @json_schema_type
-class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn):
+class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
     agent_id: str
     session_id: str
@@ -397,24 +389,24 @@ class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn):


 @json_schema_type
-class AgenticSystemTurnResponseStreamChunk(BaseModel):
-    event: AgenticSystemTurnResponseEvent
+class AgentTurnResponseStreamChunk(BaseModel):
+    event: AgentTurnResponseEvent


 @json_schema_type
-class AgenticSystemStepResponse(BaseModel):
+class AgentStepResponse(BaseModel):
     step: Step


-class AgenticSystem(Protocol):
-    @webmethod(route="/agentic_system/create")
-    async def create_agentic_system(
+class Agents(Protocol):
+    @webmethod(route="/agents/create")
+    async def create_agent(
         self,
         agent_config: AgentConfig,
-    ) -> AgenticSystemCreateResponse: ...
+    ) -> AgentCreateResponse: ...

-    @webmethod(route="/agentic_system/turn/create")
-    async def create_agentic_system_turn(
+    @webmethod(route="/agents/turn/create")
+    async def create_agent_turn(
         self,
         agent_id: str,
         session_id: str,

@@ -426,42 +418,40 @@ class AgenticSystem(Protocol):
         ],
         attachments: Optional[List[Attachment]] = None,
         stream: Optional[bool] = False,
-    ) -> AgenticSystemTurnResponseStreamChunk: ...
+    ) -> AgentTurnResponseStreamChunk: ...

-    @webmethod(route="/agentic_system/turn/get")
-    async def get_agentic_system_turn(
+    @webmethod(route="/agents/turn/get")
+    async def get_agents_turn(
         self,
         agent_id: str,
         turn_id: str,
     ) -> Turn: ...

-    @webmethod(route="/agentic_system/step/get")
-    async def get_agentic_system_step(
+    @webmethod(route="/agents/step/get")
+    async def get_agents_step(
         self, agent_id: str, turn_id: str, step_id: str
-    ) -> AgenticSystemStepResponse: ...
+    ) -> AgentStepResponse: ...

-    @webmethod(route="/agentic_system/session/create")
-    async def create_agentic_system_session(
+    @webmethod(route="/agents/session/create")
+    async def create_agent_session(
         self,
         agent_id: str,
         session_name: str,
-    ) -> AgenticSystemSessionCreateResponse: ...
+    ) -> AgentSessionCreateResponse: ...

-    @webmethod(route="/agentic_system/session/get")
-    async def get_agentic_system_session(
+    @webmethod(route="/agents/session/get")
+    async def get_agents_session(
         self,
         agent_id: str,
         session_id: str,
         turn_ids: Optional[List[str]] = None,
     ) -> Session: ...

-    @webmethod(route="/agentic_system/session/delete")
-    async def delete_agentic_system_session(
-        self, agent_id: str, session_id: str
-    ) -> None: ...
+    @webmethod(route="/agents/session/delete")
+    async def delete_agents_session(self, agent_id: str, session_id: str) -> None: ...

-    @webmethod(route="/agentic_system/delete")
-    async def delete_agentic_system(
+    @webmethod(route="/agents/delete")
+    async def delete_agents(
         self,
         agent_id: str,
     ) -> None: ...
@@ -18,44 +18,42 @@ from termcolor import cprint
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.core.datatypes import RemoteProviderConfig

-from .api import *  # noqa: F403
+from .agents import *  # noqa: F403
 from .event_logger import EventLogger


 async def get_client_impl(config: RemoteProviderConfig, _deps):
-    return AgenticSystemClient(config.url)
+    return AgentsClient(config.url)


 def encodable_dict(d: BaseModel):
     return json.loads(d.json())


-class AgenticSystemClient(AgenticSystem):
+class AgentsClient(Agents):
     def __init__(self, base_url: str):
         self.base_url = base_url

-    async def create_agentic_system(
-        self, agent_config: AgentConfig
-    ) -> AgenticSystemCreateResponse:
+    async def create_agent(self, agent_config: AgentConfig) -> AgentCreateResponse:
         async with httpx.AsyncClient() as client:
             response = await client.post(
-                f"{self.base_url}/agentic_system/create",
+                f"{self.base_url}/agents/create",
                 json={
                     "agent_config": encodable_dict(agent_config),
                 },
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
-            return AgenticSystemCreateResponse(**response.json())
+            return AgentCreateResponse(**response.json())

-    async def create_agentic_system_session(
+    async def create_agent_session(
         self,
         agent_id: str,
         session_name: str,
-    ) -> AgenticSystemSessionCreateResponse:
+    ) -> AgentSessionCreateResponse:
         async with httpx.AsyncClient() as client:
             response = await client.post(
-                f"{self.base_url}/agentic_system/session/create",
+                f"{self.base_url}/agents/session/create",
                 json={
                     "agent_id": agent_id,
                     "session_name": session_name,

@@ -63,16 +61,16 @@ class AgenticSystemClient(AgenticSystem):
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
-            return AgenticSystemSessionCreateResponse(**response.json())
+            return AgentSessionCreateResponse(**response.json())

-    async def create_agentic_system_turn(
+    async def create_agent_turn(
         self,
-        request: AgenticSystemTurnCreateRequest,
+        request: AgentTurnCreateRequest,
     ) -> AsyncGenerator:
         async with httpx.AsyncClient() as client:
             async with client.stream(
                 "POST",
-                f"{self.base_url}/agentic_system/turn/create",
+                f"{self.base_url}/agents/turn/create",
                 json=encodable_dict(request),
                 headers={"Content-Type": "application/json"},
                 timeout=20,

@@ -86,7 +84,7 @@ class AgenticSystemClient(AgenticSystem):
                         cprint(data, "red")
                         continue

-                    yield AgenticSystemTurnResponseStreamChunk(**jdata)
+                    yield AgentTurnResponseStreamChunk(**jdata)
                 except Exception as e:
                     print(data)
                     print(f"Error with parsing or validation: {e}")

@@ -102,16 +100,16 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
         tool_prompt_format=ToolPromptFormat.function_tag,
     )

-    create_response = await api.create_agentic_system(agent_config)
-    session_response = await api.create_agentic_system_session(
+    create_response = await api.create_agent(agent_config)
+    session_response = await api.create_agent_session(
         agent_id=create_response.agent_id,
         session_name="test_session",
     )

     for content in user_prompts:
         cprint(f"User> {content}", color="white", attrs=["bold"])
-        iterator = api.create_agentic_system_turn(
-            AgenticSystemTurnCreateRequest(
+        iterator = api.create_agent_turn(
+            AgentTurnCreateRequest(
                 agent_id=create_response.agent_id,
                 session_id=session_response.session_id,
                 messages=[

@@ -128,7 +126,7 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):


 async def run_main(host: str, port: int):
-    api = AgenticSystemClient(f"http://{host}:{port}")
+    api = AgentsClient(f"http://{host}:{port}")

     tool_definitions = [
         SearchToolDefinition(engine=SearchEngineType.bing),

@@ -165,7 +163,7 @@ async def run_main(host: str, port: int):


 async def run_rag(host: str, port: int):
-    api = AgenticSystemClient(f"http://{host}:{port}")
+    api = AgentsClient(f"http://{host}:{port}")

     urls = [
         "memory_optimizations.rst",
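Put together, the renamed client surface is driven roughly like this (a sketch, not part of the diff: it assumes a stack server on localhost:5000 and that `AgentConfig` takes `model` and `instructions`, as in the `_run_agent` helper above):

```python
import asyncio

from llama_models.llama3.api.datatypes import UserMessage
from llama_stack.apis.agents import AgentConfig, AgentTurnCreateRequest
from llama_stack.apis.agents.client import AgentsClient


async def main() -> None:
    api = AgentsClient("http://localhost:5000")

    # create_agentic_system -> create_agent
    create = await api.create_agent(
        AgentConfig(
            model="Meta-Llama3.1-8B-Instruct",
            instructions="You are a helpful assistant.",
        )
    )
    # create_agentic_system_session -> create_agent_session
    session = await api.create_agent_session(
        agent_id=create.agent_id, session_name="demo"
    )

    # create_agentic_system_turn -> create_agent_turn; yields stream chunks
    async for chunk in api.create_agent_turn(
        AgentTurnCreateRequest(
            agent_id=create.agent_id,
            session_id=session.session_id,
            messages=[UserMessage(content="Hello!")],
            stream=True,
        )
    ):
        print(chunk.event.payload.event_type)


asyncio.run(main())
```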
llama_stack/apis/batch_inference/__init__.py (new file, +7)

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .batch_inference import *  # noqa: F401 F403
@@ -11,7 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field

 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.inference.api import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403


 @json_schema_type
llama_stack/apis/dataset/__init__.py (new file, +7)

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .dataset import *  # noqa: F401 F403
@@ -4,4 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .api import *  # noqa: F401 F403
+from .evals import *  # noqa: F401 F403
@@ -12,7 +12,7 @@ from llama_models.schema_utils import webmethod
 from pydantic import BaseModel

 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.dataset.api import *  # noqa: F403
+from llama_stack.apis.dataset import *  # noqa: F403
 from llama_stack.common.training_types import *  # noqa: F403

llama_stack/apis/inference/__init__.py (new file, +7)

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .inference import *  # noqa: F401 F403
@@ -10,12 +10,14 @@ from typing import Any, AsyncGenerator

 import fire
 import httpx

-from llama_stack.core.datatypes import RemoteProviderConfig
 from pydantic import BaseModel
 from termcolor import cprint

+from llama_stack.core.datatypes import RemoteProviderConfig
+from .event_logger import EventLogger
+
-from .api import (
+from .inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseStreamChunk,

@@ -23,7 +25,6 @@ from .api import (
     Inference,
     UserMessage,
 )
-from .event_logger import EventLogger


 async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
llama_stack/apis/memory/__init__.py (new file, +7)

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .memory import *  # noqa: F401 F403
@@ -13,11 +13,11 @@ from typing import Any, Dict, List, Optional

 import fire
 import httpx
-from termcolor import cprint

 from llama_stack.core.datatypes import RemoteProviderConfig
+from termcolor import cprint

-from .api import *  # noqa: F403
+from .memory import *  # noqa: F403
 from .common.file_utils import data_url_from_file

llama_stack/apis/models/__init__.py (new file, +7)

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .models import *  # noqa: F401 F403
llama_stack/apis/post_training/__init__.py (new file, +7)

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .post_training import *  # noqa: F401 F403
@@ -14,7 +14,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field

 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.dataset.api import *  # noqa: F403
+from llama_stack.apis.dataset import *  # noqa: F403
 from llama_stack.common.training_types import *  # noqa: F403

llama_stack/apis/reward_scoring/__init__.py (new file, +7)

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .reward_scoring import *  # noqa: F401 F403
llama_stack/apis/safety/__init__.py (new file, +7)

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .safety import *  # noqa: F401 F403
@@ -13,12 +13,12 @@ import fire
 import httpx

 from llama_models.llama3.api.datatypes import UserMessage
-
-from llama_stack.core.datatypes import RemoteProviderConfig
 from pydantic import BaseModel
 from termcolor import cprint

+from llama_stack.core.datatypes import RemoteProviderConfig
+
-from .api import *  # noqa: F403
+from .safety import *  # noqa: F403


 async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety:
llama_stack/apis/synthetic_data_generation/__init__.py (new file, +7)

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .synthetic_data_generation import *  # noqa: F401 F403
@@ -13,7 +13,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel

 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.reward_scoring.api import *  # noqa: F403
+from llama_stack.apis.reward_scoring import *  # noqa: F403


 class FilteringFunction(Enum):
llama_stack/apis/telemetry/__init__.py (new file, +7)

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .telemetry import *  # noqa: F401 F403
@@ -31,16 +31,6 @@ class LlamaCLIParser:
         ModelParser.create(subparsers)
         StackParser.create(subparsers)

-        # Import sub-commands from agentic_system if they exist
-        try:
-            from llama_agentic_system.cli.subcommand_modules import SUBCOMMAND_MODULES
-
-            for module in SUBCOMMAND_MODULES:
-                module.create(subparsers)
-
-        except ImportError:
-            pass
-
     def parse_args(self) -> argparse.Namespace:
         return self.parser.parse_args()

@@ -5,6 +5,6 @@ distribution_spec:
   inference: meta-reference
   memory: meta-reference-faiss
   safety: meta-reference
-  agentic_system: meta-reference
+  agents: meta-reference
   telemetry: console
 image_type: conda

@@ -5,6 +5,6 @@ distribution_spec:
   inference: remote::fireworks
   memory: meta-reference-faiss
   safety: meta-reference
-  agentic_system: meta-reference
+  agents: meta-reference
   telemetry: console
 image_type: conda

@@ -5,6 +5,6 @@ distribution_spec:
   inference: remote::ollama
   memory: meta-reference-faiss
   safety: meta-reference
-  agentic_system: meta-reference
+  agents: meta-reference
   telemetry: console
 image_type: conda

@@ -5,6 +5,6 @@ distribution_spec:
   inference: remote::tgi
   memory: meta-reference-faiss
   safety: meta-reference
-  agentic_system: meta-reference
+  agents: meta-reference
   telemetry: console
 image_type: conda

@@ -5,6 +5,6 @@ distribution_spec:
   inference: remote::together
   memory: meta-reference-faiss
   safety: meta-reference
-  agentic_system: meta-reference
+  agents: meta-reference
   telemetry: console
 image_type: conda

@@ -5,6 +5,6 @@ distribution_spec:
   inference: meta-reference
   memory: meta-reference-faiss
   safety: meta-reference
-  agentic_system: meta-reference
+  agents: meta-reference
   telemetry: console
 image_type: docker
@@ -17,7 +17,7 @@ from pydantic import BaseModel, Field, validator
 class Api(Enum):
     inference = "inference"
     safety = "safety"
-    agentic_system = "agentic_system"
+    agents = "agents"
     memory = "memory"
     telemetry = "telemetry"

@@ -8,11 +8,11 @@ import importlib
 import inspect
 from typing import Dict, List

-from llama_stack.agentic_system.api import AgenticSystem
-from llama_stack.inference.api import Inference
-from llama_stack.memory.api import Memory
-from llama_stack.safety.api import Safety
-from llama_stack.telemetry.api import Telemetry
+from llama_stack.apis.agents import Agents
+from llama_stack.apis.inference import Inference
+from llama_stack.apis.memory import Memory
+from llama_stack.apis.safety import Safety
+from llama_stack.apis.telemetry import Telemetry

 from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec


@@ -34,7 +34,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
     protocols = {
         Api.inference: Inference,
         Api.safety: Safety,
-        Api.agentic_system: AgenticSystem,
+        Api.agents: Agents,
         Api.memory: Memory,
         Api.telemetry: Telemetry,
     }

@@ -67,7 +67,7 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
     ret = {}
     for api in stack_apis():
         name = api.name.lower()
-        module = importlib.import_module(f"llama_stack.{name}.providers")
+        module = importlib.import_module(f"llama_stack.providers.registry.{name}")
         ret[api] = {
             "remote": remote_provider_spec(api),
             **{a.provider_id: a for a in module.available_providers()},
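The one-line change in `api_providers` encodes the new convention: provider registries now live under `llama_stack.providers.registry.<api>` rather than `llama_stack.<api>.providers`. Stripped to its essentials, the discovery loop is (a sketch; `available_providers()` and `provider_id` follow the surrounding file):

```python
import importlib


def discover_providers(api_names: list[str]) -> dict[str, dict]:
    ret = {}
    for name in api_names:
        # each API contributes a registry module exposing available_providers()
        module = importlib.import_module(f"llama_stack.providers.registry.{name}")
        ret[name] = {p.provider_id: p for p in module.available_providers()}
    return ret
```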
@@ -39,7 +39,7 @@ from pydantic import BaseModel, ValidationError
 from termcolor import cprint
 from typing_extensions import Annotated

-from llama_stack.telemetry.tracing import (
+from llama_stack.providers.utils.telemetry.tracing import (
     end_trace,
     setup_logger,
     SpanStatus,
(three deleted package __init__.py files, -7 each)

@@ -1,7 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .api import *  # noqa: F401 F403

@@ -1,7 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .api import *  # noqa: F401 F403

@@ -1,7 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .api import *  # noqa: F401 F403
@@ -6,15 +6,16 @@

 from typing import AsyncGenerator

-from fireworks.client import Fireworks
 from llama_models.llama3.api.chat_format import ChatFormat
-
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model

-from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.inference.prepare_messages import prepare_messages
+from fireworks.client import Fireworks
+
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.providers.utils.inference.prepare_messages import prepare_messages

 from .config import FireworksImplConfig

@@ -81,7 +82,7 @@ class FireworksInferenceAdapter(Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        tools: Optional[List[ToolDefinition]] = list(),
+        tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,

@@ -91,7 +92,7 @@ class FireworksInferenceAdapter(Inference):
             model=model,
             messages=messages,
             sampling_params=sampling_params,
-            tools=tools,
+            tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
             stream=stream,
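Beyond the import moves, the adapter hunks here and below fix a real bug: `tools: Optional[List[ToolDefinition]] = list()` creates one shared list at function-definition time, the classic Python mutable-default pitfall. The new code defaults to `None` and normalizes with `tools or []` at the use site. A minimal demonstration of why this matters:

```python
def bad(items=list()):
    items.append("x")      # mutates the single shared default list
    return items


def good(items=None):
    items = items or []    # fresh list on every call
    items.append("x")
    return items


assert bad() == ["x"]
assert bad() == ["x", "x"]   # state leaked between calls
assert good() == ["x"]
assert good() == ["x"]       # independent calls
```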
@@ -12,10 +12,11 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model

 from ollama import AsyncClient

-from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.inference.prepare_messages import prepare_messages
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.providers.utils.inference.prepare_messages import prepare_messages

 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models

@@ -89,7 +90,7 @@ class OllamaInferenceAdapter(Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        tools: Optional[List[ToolDefinition]] = list(),
+        tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,

@@ -99,7 +100,7 @@ class OllamaInferenceAdapter(Inference):
             model=model,
             messages=messages,
             sampling_params=sampling_params,
-            tools=tools,
+            tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
             stream=stream,
@@ -13,8 +13,8 @@ from huggingface_hub import HfApi, InferenceClient
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.inference.prepare_messages import prepare_messages
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.providers.utils.inference.prepare_messages import prepare_messages

 from .config import TGIImplConfig

@@ -87,7 +87,7 @@ class TGIAdapter(Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        tools: Optional[List[ToolDefinition]] = list(),
+        tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,

@@ -97,7 +97,7 @@ class TGIAdapter(Inference):
             model=model,
             messages=messages,
             sampling_params=sampling_params,
-            tools=tools,
+            tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
             stream=stream,
@@ -11,10 +11,11 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model

 from together import Together

-from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.inference.prepare_messages import prepare_messages
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.providers.utils.inference.prepare_messages import prepare_messages

 from .config import TogetherImplConfig

@@ -81,7 +82,7 @@ class TogetherInferenceAdapter(Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        tools: Optional[List[ToolDefinition]] = list(),
+        tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,

@@ -92,7 +93,7 @@ class TogetherInferenceAdapter(Inference):
             model=model,
             messages=messages,
             sampling_params=sampling_params,
-            tools=tools,
+            tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
             stream=stream,
@@ -12,10 +12,13 @@ from urllib.parse import urlparse
 import chromadb
 from numpy.typing import NDArray

-from llama_stack.memory.api import *  # noqa: F403
+from llama_stack.apis.memory import *  # noqa: F403


-from llama_stack.memory.common.vector_store import BankWithIndex, EmbeddingIndex
+from llama_stack.providers.utils.memory.vector_store import (
+    BankWithIndex,
+    EmbeddingIndex,
+)


 class ChromaIndex(EmbeddingIndex):
@@ -13,10 +13,10 @@ from numpy.typing import NDArray
 from psycopg2 import sql
 from psycopg2.extras import execute_values, Json
 from pydantic import BaseModel
-from llama_stack.memory.api import *  # noqa: F403
+from llama_stack.apis.memory import *  # noqa: F403


-from llama_stack.memory.common.vector_store import (
+from llama_stack.providers.utils.memory.vector_store import (
     ALL_MINILM_L6_V2_DIMENSION,
     BankWithIndex,
     EmbeddingIndex,
@@ -14,13 +14,13 @@ from .config import MetaReferenceImplConfig
 async def get_provider_impl(
     config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
 ):
-    from .agentic_system import MetaReferenceAgenticSystemImpl
+    from .agents import MetaReferenceAgentsImpl

     assert isinstance(
         config, MetaReferenceImplConfig
     ), f"Unexpected config type: {type(config)}"

-    impl = MetaReferenceAgenticSystemImpl(
+    impl = MetaReferenceAgentsImpl(
         config,
         deps[Api.inference],
         deps[Api.memory],
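`get_provider_impl` is the inline-provider entry point: the core resolves a provider spec's `module` string (see the registry hunk at the end of this diff), imports it, and awaits this coroutine with the validated config and API dependencies. Roughly (a sketch; error handling and config validation omitted):

```python
import importlib


async def instantiate_inline_provider(provider_spec, config, deps):
    # e.g. provider_spec.module == "llama_stack.providers.impls.meta_reference.agents"
    module = importlib.import_module(provider_spec.module)
    # the module's get_provider_impl(config, deps) builds and returns the impl
    return await module.get_provider_impl(config, deps)
```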
@@ -20,10 +20,10 @@ import httpx

 from termcolor import cprint

-from llama_stack.agentic_system.api import *  # noqa: F403
-from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.memory.api import *  # noqa: F403
-from llama_stack.safety.api import *  # noqa: F403
+from llama_stack.apis.agents import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.apis.safety import *  # noqa: F403

 from llama_stack.tools.base import BaseTool
 from llama_stack.tools.builtin import (

@@ -122,7 +122,7 @@ class ChatAgent(ShieldRunnerMixin):
         return session

     async def create_and_execute_turn(
-        self, request: AgenticSystemTurnCreateRequest
+        self, request: AgentTurnCreateRequest
     ) -> AsyncGenerator:
         assert (
             request.session_id in self.sessions

@@ -141,9 +141,9 @@ class ChatAgent(ShieldRunnerMixin):

         turn_id = str(uuid.uuid4())
         start_time = datetime.now()
-        yield AgenticSystemTurnResponseStreamChunk(
-            event=AgenticSystemTurnResponseEvent(
-                payload=AgenticSystemTurnResponseTurnStartPayload(
+        yield AgentTurnResponseStreamChunk(
+            event=AgentTurnResponseEvent(
+                payload=AgentTurnResponseTurnStartPayload(
                     turn_id=turn_id,
                 )
             )

@@ -169,12 +169,12 @@ class ChatAgent(ShieldRunnerMixin):
                 continue

             assert isinstance(
-                chunk, AgenticSystemTurnResponseStreamChunk
+                chunk, AgentTurnResponseStreamChunk
             ), f"Unexpected type {type(chunk)}"
             event = chunk.event
             if (
                 event.payload.event_type
-                == AgenticSystemTurnResponseEventType.step_complete.value
+                == AgentTurnResponseEventType.step_complete.value
             ):
                 steps.append(event.payload.step_details)

@@ -193,9 +193,9 @@ class ChatAgent(ShieldRunnerMixin):
         )
         session.turns.append(turn)

-        chunk = AgenticSystemTurnResponseStreamChunk(
-            event=AgenticSystemTurnResponseEvent(
-                payload=AgenticSystemTurnResponseTurnCompletePayload(
+        chunk = AgentTurnResponseStreamChunk(
+            event=AgentTurnResponseEvent(
+                payload=AgentTurnResponseTurnCompletePayload(
                     turn=turn,
                 )
             )

@@ -261,9 +261,9 @@ class ChatAgent(ShieldRunnerMixin):

         step_id = str(uuid.uuid4())
         try:
-            yield AgenticSystemTurnResponseStreamChunk(
-                event=AgenticSystemTurnResponseEvent(
-                    payload=AgenticSystemTurnResponseStepStartPayload(
+            yield AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseStepStartPayload(
                         step_type=StepType.shield_call.value,
                         step_id=step_id,
                         metadata=dict(touchpoint=touchpoint),

@@ -273,9 +273,9 @@ class ChatAgent(ShieldRunnerMixin):
             await self.run_shields(messages, shields)

         except SafetyException as e:
-            yield AgenticSystemTurnResponseStreamChunk(
-                event=AgenticSystemTurnResponseEvent(
-                    payload=AgenticSystemTurnResponseStepCompletePayload(
+            yield AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseStepCompletePayload(
                         step_type=StepType.shield_call.value,
                         step_details=ShieldCallStep(
                             step_id=step_id,

@@ -292,9 +292,9 @@ class ChatAgent(ShieldRunnerMixin):
             )
             yield False

-        yield AgenticSystemTurnResponseStreamChunk(
-            event=AgenticSystemTurnResponseEvent(
-                payload=AgenticSystemTurnResponseStepCompletePayload(
+        yield AgentTurnResponseStreamChunk(
+            event=AgentTurnResponseEvent(
+                payload=AgentTurnResponseStepCompletePayload(
                     step_type=StepType.shield_call.value,
                     step_details=ShieldCallStep(
                         step_id=step_id,

@@ -325,9 +325,9 @@ class ChatAgent(ShieldRunnerMixin):
         )
         if need_rag_context:
             step_id = str(uuid.uuid4())
-            yield AgenticSystemTurnResponseStreamChunk(
-                event=AgenticSystemTurnResponseEvent(
-                    payload=AgenticSystemTurnResponseStepStartPayload(
+            yield AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseStepStartPayload(
                         step_type=StepType.memory_retrieval.value,
                         step_id=step_id,
                     )

@@ -341,9 +341,9 @@ class ChatAgent(ShieldRunnerMixin):
             )

             step_id = str(uuid.uuid4())
-            yield AgenticSystemTurnResponseStreamChunk(
-                event=AgenticSystemTurnResponseEvent(
-                    payload=AgenticSystemTurnResponseStepCompletePayload(
+            yield AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseStepCompletePayload(
                         step_type=StepType.memory_retrieval.value,
                         step_id=step_id,
                         step_details=MemoryRetrievalStep(

@@ -360,7 +360,7 @@ class ChatAgent(ShieldRunnerMixin):
             last_message = input_messages[-1]
             last_message.context = "\n".join(rag_context)

-        elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools:
+        elif attachments and AgentTool.code_interpreter.value in enabled_tools:
             urls = [a.content for a in attachments if isinstance(a.content, URL)]
             msg = await attachment_message(self.tempdir, urls)
             input_messages.append(msg)

@@ -379,9 +379,9 @@ class ChatAgent(ShieldRunnerMixin):
             cprint(f"{str(msg)}", color=color)

         step_id = str(uuid.uuid4())
-        yield AgenticSystemTurnResponseStreamChunk(
-            event=AgenticSystemTurnResponseEvent(
-                payload=AgenticSystemTurnResponseStepStartPayload(
+        yield AgentTurnResponseStreamChunk(
+            event=AgentTurnResponseEvent(
+                payload=AgentTurnResponseStepStartPayload(
                     step_type=StepType.inference.value,
                     step_id=step_id,
                 )

@@ -412,9 +412,9 @@ class ChatAgent(ShieldRunnerMixin):
                     tool_calls.append(delta.content)

                 if stream:
-                    yield AgenticSystemTurnResponseStreamChunk(
-                        event=AgenticSystemTurnResponseEvent(
-                            payload=AgenticSystemTurnResponseStepProgressPayload(
+                    yield AgentTurnResponseStreamChunk(
+                        event=AgentTurnResponseEvent(
+                            payload=AgentTurnResponseStepProgressPayload(
                                 step_type=StepType.inference.value,
                                 step_id=step_id,
                                 model_response_text_delta="",

@@ -426,9 +426,9 @@ class ChatAgent(ShieldRunnerMixin):
             elif isinstance(delta, str):
                 content += delta
                 if stream and event.stop_reason is None:
-                    yield AgenticSystemTurnResponseStreamChunk(
-                        event=AgenticSystemTurnResponseEvent(
-                            payload=AgenticSystemTurnResponseStepProgressPayload(
+                    yield AgentTurnResponseStreamChunk(
+                        event=AgentTurnResponseEvent(
+                            payload=AgentTurnResponseStepProgressPayload(
                                 step_type=StepType.inference.value,
                                 step_id=step_id,
                                 model_response_text_delta=event.delta,

@@ -448,9 +448,9 @@ class ChatAgent(ShieldRunnerMixin):
             tool_calls=tool_calls,
         )

-        yield AgenticSystemTurnResponseStreamChunk(
-            event=AgenticSystemTurnResponseEvent(
-                payload=AgenticSystemTurnResponseStepCompletePayload(
+        yield AgentTurnResponseStreamChunk(
+            event=AgentTurnResponseEvent(
+                payload=AgentTurnResponseStepCompletePayload(
                     step_type=StepType.inference.value,
                     step_id=step_id,
                     step_details=InferenceStep(

@@ -498,17 +498,17 @@ class ChatAgent(ShieldRunnerMixin):
             return

         step_id = str(uuid.uuid4())
-        yield AgenticSystemTurnResponseStreamChunk(
-            event=AgenticSystemTurnResponseEvent(
-                payload=AgenticSystemTurnResponseStepStartPayload(
+        yield AgentTurnResponseStreamChunk(
+            event=AgentTurnResponseEvent(
+                payload=AgentTurnResponseStepStartPayload(
                     step_type=StepType.tool_execution.value,
                     step_id=step_id,
                 )
             )
         )
-        yield AgenticSystemTurnResponseStreamChunk(
-            event=AgenticSystemTurnResponseEvent(
-                payload=AgenticSystemTurnResponseStepProgressPayload(
+        yield AgentTurnResponseStreamChunk(
+            event=AgentTurnResponseEvent(
+                payload=AgentTurnResponseStepProgressPayload(
                     step_type=StepType.tool_execution.value,
                     step_id=step_id,
                     tool_call=tool_call,

@@ -525,9 +525,9 @@ class ChatAgent(ShieldRunnerMixin):
         ), "Currently not supporting multiple messages"
         result_message = result_messages[0]

-        yield AgenticSystemTurnResponseStreamChunk(
-            event=AgenticSystemTurnResponseEvent(
-                payload=AgenticSystemTurnResponseStepCompletePayload(
+        yield AgentTurnResponseStreamChunk(
+            event=AgentTurnResponseEvent(
+                payload=AgentTurnResponseStepCompletePayload(
                     step_type=StepType.tool_execution.value,
                     step_details=ToolExecutionStep(
                         step_id=step_id,

@@ -547,9 +547,9 @@ class ChatAgent(ShieldRunnerMixin):

         # TODO: add tool-input touchpoint and a "start" event for this step also
         # but that needs a lot more refactoring of Tool code potentially
-        yield AgenticSystemTurnResponseStreamChunk(
-            event=AgenticSystemTurnResponseEvent(
-                payload=AgenticSystemTurnResponseStepCompletePayload(
+        yield AgentTurnResponseStreamChunk(
+            event=AgentTurnResponseEvent(
+                payload=AgentTurnResponseStepCompletePayload(
                     step_type=StepType.shield_call.value,
                     step_details=ShieldCallStep(
                         step_id=str(uuid.uuid4()),

@@ -566,9 +566,9 @@ class ChatAgent(ShieldRunnerMixin):
             )

         except SafetyException as e:
-            yield AgenticSystemTurnResponseStreamChunk(
-                event=AgenticSystemTurnResponseEvent(
-                    payload=AgenticSystemTurnResponseStepCompletePayload(
+            yield AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseStepCompletePayload(
                         step_type=StepType.shield_call.value,
                         step_details=ShieldCallStep(
                             step_id=str(uuid.uuid4()),

@@ -616,18 +616,18 @@ class ChatAgent(ShieldRunnerMixin):
         enabled_tools = set(t.type for t in self.agent_config.tools)
         if attachments:
             if (
-                AgenticSystemTool.code_interpreter.value in enabled_tools
+                AgentTool.code_interpreter.value in enabled_tools
                 and self.agent_config.tool_choice == ToolChoice.required
             ):
                 return False
             else:
                 return True

-        return AgenticSystemTool.memory.value in enabled_tools
+        return AgentTool.memory.value in enabled_tools

     def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
         for t in self.agent_config.tools:
-            if t.type == AgenticSystemTool.memory.value:
+            if t.type == AgentTool.memory.value:
                 return t

         return None
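Every step of a turn flows through the same three-level envelope, now named `AgentTurnResponseStreamChunk` -> `AgentTurnResponseEvent` -> one of the typed payloads. A consumer dispatches on the `event_type` discriminator; a sketch using the post-rename names (the iterator could be `AgentsClient.create_agent_turn(...)` from earlier in this diff):

```python
from llama_stack.apis.agents import AgentTurnResponseEventType


async def print_turn(iterator) -> None:
    async for chunk in iterator:  # AgentTurnResponseStreamChunk
        payload = chunk.event.payload
        if payload.event_type == AgentTurnResponseEventType.step_progress.value:
            # inference text arrives as incremental deltas
            delta = getattr(payload, "model_response_text_delta", None)
            if delta:
                print(delta, end="", flush=True)
        elif payload.event_type == AgentTurnResponseEventType.turn_complete.value:
            print()  # the completed Turn is in payload.turn
```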
@@ -10,10 +10,10 @@ import tempfile
 import uuid
 from typing import AsyncGenerator

-from llama_stack.inference.api import Inference
-from llama_stack.memory.api import Memory
-from llama_stack.safety.api import Safety
-from llama_stack.agentic_system.api import *  # noqa: F403
+from llama_stack.apis.inference import Inference
+from llama_stack.apis.memory import Memory
+from llama_stack.apis.safety import Safety
+from llama_stack.apis.agents import *  # noqa: F403
 from llama_stack.tools.builtin import (
     CodeInterpreterTool,
     PhotogenTool,

@@ -33,7 +33,7 @@ logger.setLevel(logging.INFO)
 AGENT_INSTANCES_BY_ID = {}


-class MetaReferenceAgenticSystemImpl(AgenticSystem):
+class MetaReferenceAgentsImpl(Agents):
     def __init__(
         self,
         config: MetaReferenceImplConfig,

@@ -49,10 +49,10 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
     async def initialize(self) -> None:
         pass

-    async def create_agentic_system(
+    async def create_agent(
         self,
         agent_config: AgentConfig,
-    ) -> AgenticSystemCreateResponse:
+    ) -> AgentCreateResponse:
         agent_id = str(uuid.uuid4())

         builtin_tools = []

@@ -95,24 +95,24 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
             builtin_tools=builtin_tools,
         )

-        return AgenticSystemCreateResponse(
+        return AgentCreateResponse(
             agent_id=agent_id,
         )

-    async def create_agentic_system_session(
+    async def create_agent_session(
         self,
         agent_id: str,
         session_name: str,
-    ) -> AgenticSystemSessionCreateResponse:
+    ) -> AgentSessionCreateResponse:
         assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
         agent = AGENT_INSTANCES_BY_ID[agent_id]

         session = agent.create_session(session_name)
-        return AgenticSystemSessionCreateResponse(
+        return AgentSessionCreateResponse(
             session_id=session.session_id,
         )

-    async def create_agentic_system_turn(
+    async def create_agent_turn(
         self,
         agent_id: str,
         session_id: str,

@@ -126,7 +126,7 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
         stream: Optional[bool] = False,
     ) -> AsyncGenerator:
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
-        request = AgenticSystemTurnCreateRequest(
+        request = AgentTurnCreateRequest(
             agent_id=agent_id,
             session_id=session_id,
             messages=messages,
@@ -10,14 +10,14 @@ from jinja2 import Template
 from llama_models.llama3.api import *  # noqa: F403


-from llama_stack.agentic_system.api import (
+from llama_stack.apis.agents import (
     DefaultMemoryQueryGeneratorConfig,
     LLMMemoryQueryGeneratorConfig,
     MemoryQueryGenerator,
     MemoryQueryGeneratorConfig,
 )
 from termcolor import cprint  # noqa: F401
-from llama_stack.inference.api import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403


 async def generate_rag_query(
@@ -7,15 +7,15 @@
 from typing import List

 from llama_models.llama3.api.datatypes import Message, Role, UserMessage
-from termcolor import cprint

-from llama_stack.safety.api import (
+from llama_stack.apis.safety import (
     OnViolationAction,
     RunShieldRequest,
     Safety,
     ShieldDefinition,
     ShieldResponse,
 )
+from termcolor import cprint


 class SafetyException(Exception):  # noqa: N818
@@ -11,9 +11,9 @@ from llama_models.datatypes import ModelFamily
 from llama_models.schema_utils import json_schema_type
 from llama_models.sku_list import all_registered_models, resolve_model

-from pydantic import BaseModel, Field, field_validator
+from llama_stack.apis.inference import QuantizationConfig

-from llama_stack.inference.api import QuantizationConfig
+from pydantic import BaseModel, Field, field_validator


 @json_schema_type
@@ -28,10 +28,10 @@ from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.sku_list import resolve_model
-from termcolor import cprint
+from llama_stack.apis.inference import QuantizationType

 from llama_stack.common.model_utils import model_local_dir
-from llama_stack.inference.api import QuantizationType
+from termcolor import cprint

 from .config import MetaReferenceImplConfig
@@ -11,7 +11,7 @@ from typing import AsyncIterator, Union
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.sku_list import resolve_model

-from llama_stack.inference.api import (
+from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseEvent,

@@ -21,13 +21,13 @@ from llama_stack.inference.api import (
     ToolCallDelta,
     ToolCallParseStatus,
 )
-from llama_stack.inference.prepare_messages import prepare_messages
+from llama_stack.providers.utils.inference.prepare_messages import prepare_messages

 from .config import MetaReferenceImplConfig
 from .model_parallel import LlamaModelParallelGenerator

 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.inference.api import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403

 # there's a single model parallel process running serving the model. for now,
 # we don't support multiple concurrent requests to this process.

@@ -57,7 +57,7 @@ class MetaReferenceInferenceImpl(Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        tools: Optional[List[ToolDefinition]] = list(),
+        tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,

@@ -70,7 +70,7 @@ class MetaReferenceInferenceImpl(Inference):
             model=model,
             messages=messages,
             sampling_params=sampling_params,
-            tools=tools,
+            tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
             stream=stream,
@@ -14,9 +14,9 @@ import torch

 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
 from llama_models.llama3.api.model import Transformer, TransformerBlock
-from llama_stack.inference.api import QuantizationType
+from llama_stack.apis.inference import QuantizationType

-from llama_stack.inference.api.config import (
+from llama_stack.apis.inference.config import (
     CheckpointQuantizationFormat,
     MetaReferenceImplConfig,
 )
@@ -15,13 +15,14 @@ from numpy.typing import NDArray

 from llama_models.llama3.api.datatypes import *  # noqa: F403

-from llama_stack.memory.api import *  # noqa: F403
-from llama_stack.memory.common.vector_store import (
+from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.providers.utils.memory.vector_store import (
     ALL_MINILM_L6_V2_DIMENSION,
     BankWithIndex,
     EmbeddingIndex,
 )
-from llama_stack.telemetry import tracing
+from llama_stack.providers.utils.telemetry import tracing

 from .config import FaissImplConfig

 logger = logging.getLogger(__name__)
@@ -9,7 +9,7 @@ import asyncio
 from llama_models.sku_list import resolve_model

 from llama_stack.common.model_utils import model_local_dir
-from llama_stack.safety.api import *  # noqa
+from llama_stack.apis.safety import *  # noqa

 from .config import SafetyConfig
 from .shields import (
@@ -8,7 +8,7 @@ from abc import ABC, abstractmethod
 from typing import List

 from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
-from llama_stack.safety.api import *  # noqa: F403
+from llama_stack.apis.safety import *  # noqa: F403

 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"

@@ -8,7 +8,7 @@ from codeshield.cs import CodeShield
 from termcolor import cprint

 from .base import ShieldResponse, TextShield
-from llama_stack.safety.api import *  # noqa: F403
+from llama_stack.apis.safety import *  # noqa: F403


 class CodeScannerShield(TextShield):
@@ -14,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, Role
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
-from llama_stack.safety.api import *  # noqa: F403
+from llama_stack.apis.safety import *  # noqa: F403

 SAFE_RESPONSE = "safe"
 _INSTANCE = None
@@ -14,7 +14,7 @@ from termcolor import cprint
 from transformers import AutoModelForSequenceClassification, AutoTokenizer

 from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
-from llama_stack.safety.api import *  # noqa: F403
+from llama_stack.apis.safety import *  # noqa: F403


 class PromptGuardShield(TextShield):
@@ -6,7 +6,7 @@

 from typing import Optional

-from llama_stack.telemetry.api import *  # noqa: F403
+from llama_stack.apis.telemetry import *  # noqa: F403
 from .config import ConsoleConfig

@@ -12,7 +12,7 @@ from llama_stack.core.datatypes import Api, InlineProviderSpec, ProviderSpec
 def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
-            api=Api.agentic_system,
+            api=Api.agents,
            provider_id="meta-reference",
            pip_packages=[
                "codeshield",

@@ -23,8 +23,8 @@ def available_providers() -> List[ProviderSpec]:
                "torch",
                "transformers",
            ],
-            module="llama_stack.agentic_system.meta_reference",
-            config_class="llama_stack.agentic_system.meta_reference.MetaReferenceImplConfig",
+            module="llama_stack.providers.impls.meta_reference.agents",
+            config_class="llama_stack.providers.impls.meta_reference.agents.MetaReferenceImplConfig",
            api_dependencies=[
                Api.inference,
                Api.safety,
Some files were not shown because too many files have changed in this diff.