Codemod from llama_toolchain -> llama_stack

- added providers/registry
- cleaned up api/ subdirectories and moved impls away
- restructured api/api.py
- from llama_stack.apis.<api> import foo should work now
- updated imports to use llama_stack.apis.<api> (see the import sketch below)
- updated many other imports
- added __init__.py files and fixed registry imports
- create_agentic_system -> create_agent
- AgenticSystem -> Agent
Ashwin Bharambe 2024-09-16 17:34:07 -07:00
parent 2cf731faea
commit 76b354a081
128 changed files with 381 additions and 376 deletions
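
A minimal before/after sketch of what the codemod does to imports and names (the paths and renames are taken from the diff below; nothing here is new API surface):

```
# old layout: per-API packages, each with an api/ submodule
# from llama_stack.inference.api import Inference
# from llama_stack.safety.api import Safety
# from llama_stack.agentic_system.api import AgenticSystem

# new layout: all API definitions live under llama_stack.apis.<api>
from llama_stack.apis.inference import Inference
from llama_stack.apis.safety import Safety
from llama_stack.apis.agents import Agents  # AgenticSystem is now Agents

# the agent API surface is renamed to match, e.g.
#   create_agentic_system()        -> create_agent()
#   /agentic_system/create (route) -> /agents/create
```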

View file

@ -482,7 +482,7 @@ Once the server is setup, we can test it with a client to see the example output
cd /path/to/llama-stack
conda activate <env> # any environment containing the llama-toolchain pip package will work
python -m llama_stack.inference.client localhost 5000
python -m llama_stack.apis.inference.client localhost 5000
```
This will run the chat completion client and query the distributions /inference/chat_completion API.

View file

@ -296,7 +296,7 @@ Once the server is setup, we can test it with a client to see the example output
cd /path/to/llama-stack
conda activate <env> # any environment containing the llama-toolchain pip package will work
python -m llama_stack.inference.client localhost 5000
python -m llama_stack.apis.inference.client localhost 5000
```
This will run the chat completion client and query the distributions /inference/chat_completion API.
@ -314,7 +314,7 @@ You know what's even more hilarious? People like you who think they can just Goo
Similarly you can test safety (if you configured llama-guard and/or prompt-guard shields) by:
```
python -m llama_stack.safety.client localhost 5000
python -m llama_stack.apis.safety.client localhost 5000
```
You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/sdk_examples) repo.

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .agents import * # noqa: F401 F403

View file

@ -15,9 +15,9 @@ from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.common.deployment_types import * # noqa: F403
from llama_stack.inference.api import * # noqa: F403
from llama_stack.safety.api import * # noqa: F403
from llama_stack.memory.api import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
@json_schema_type
@ -26,7 +26,7 @@ class Attachment(BaseModel):
mime_type: str
class AgenticSystemTool(Enum):
class AgentTool(Enum):
brave_search = "brave_search"
wolfram_alpha = "wolfram_alpha"
photogen = "photogen"
@ -50,41 +50,33 @@ class SearchEngineType(Enum):
class SearchToolDefinition(ToolDefinitionCommon):
# NOTE: brave_search is just a placeholder since model always uses
# brave_search as tool call name
type: Literal[AgenticSystemTool.brave_search.value] = (
AgenticSystemTool.brave_search.value
)
type: Literal[AgentTool.brave_search.value] = AgentTool.brave_search.value
engine: SearchEngineType = SearchEngineType.brave
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class WolframAlphaToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.wolfram_alpha.value] = (
AgenticSystemTool.wolfram_alpha.value
)
type: Literal[AgentTool.wolfram_alpha.value] = AgentTool.wolfram_alpha.value
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class PhotogenToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.photogen.value] = AgenticSystemTool.photogen.value
type: Literal[AgentTool.photogen.value] = AgentTool.photogen.value
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class CodeInterpreterToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.code_interpreter.value] = (
AgenticSystemTool.code_interpreter.value
)
type: Literal[AgentTool.code_interpreter.value] = AgentTool.code_interpreter.value
enable_inline_code_execution: bool = True
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class FunctionCallToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.function_call.value] = (
AgenticSystemTool.function_call.value
)
type: Literal[AgentTool.function_call.value] = AgentTool.function_call.value
function_name: str
description: str
parameters: Dict[str, ToolParamDefinition]
@ -95,30 +87,30 @@ class _MemoryBankConfigCommon(BaseModel):
bank_id: str
class AgenticSystemVectorMemoryBankConfig(_MemoryBankConfigCommon):
class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
class AgenticSystemKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
keys: List[str] # what keys to focus on
class AgenticSystemKeywordMemoryBankConfig(_MemoryBankConfigCommon):
class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
class AgenticSystemGraphMemoryBankConfig(_MemoryBankConfigCommon):
class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
entities: List[str] # what entities to focus on
MemoryBankConfig = Annotated[
Union[
AgenticSystemVectorMemoryBankConfig,
AgenticSystemKeyValueMemoryBankConfig,
AgenticSystemKeywordMemoryBankConfig,
AgenticSystemGraphMemoryBankConfig,
AgentVectorMemoryBankConfig,
AgentKeyValueMemoryBankConfig,
AgentKeywordMemoryBankConfig,
AgentGraphMemoryBankConfig,
],
Field(discriminator="type"),
]
@ -158,7 +150,7 @@ MemoryQueryGeneratorConfig = Annotated[
class MemoryToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
type: Literal[AgentTool.memory.value] = AgentTool.memory.value
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
# This config defines how a query is generated using the messages
# for memory bank retrieval.
@ -169,7 +161,7 @@ class MemoryToolDefinition(ToolDefinitionCommon):
max_chunks: int = 10
AgenticSystemToolDefinition = Annotated[
AgentToolDefinition = Annotated[
Union[
SearchToolDefinition,
WolframAlphaToolDefinition,
@ -275,7 +267,7 @@ class AgentConfigCommon(BaseModel):
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
tools: Optional[List[AgenticSystemToolDefinition]] = Field(default_factory=list)
tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
@ -292,7 +284,7 @@ class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: Optional[str] = None
class AgenticSystemTurnResponseEventType(Enum):
class AgentTurnResponseEventType(Enum):
step_start = "step_start"
step_complete = "step_complete"
step_progress = "step_progress"
@ -302,9 +294,9 @@ class AgenticSystemTurnResponseEventType(Enum):
@json_schema_type
class AgenticSystemTurnResponseStepStartPayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
AgenticSystemTurnResponseEventType.step_start.value
class AgentTurnResponseStepStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_start.value] = (
AgentTurnResponseEventType.step_start.value
)
step_type: StepType
step_id: str
@ -312,20 +304,20 @@ class AgenticSystemTurnResponseStepStartPayload(BaseModel):
@json_schema_type
class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
AgenticSystemTurnResponseEventType.step_complete.value
class AgentTurnResponseStepCompletePayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_complete.value] = (
AgentTurnResponseEventType.step_complete.value
)
step_type: StepType
step_details: Step
@json_schema_type
class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
class AgentTurnResponseStepProgressPayload(BaseModel):
model_config = ConfigDict(protected_namespaces=())
event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
AgenticSystemTurnResponseEventType.step_progress.value
event_type: Literal[AgentTurnResponseEventType.step_progress.value] = (
AgentTurnResponseEventType.step_progress.value
)
step_type: StepType
step_id: str
@ -336,49 +328,49 @@ class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
@json_schema_type
class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
AgenticSystemTurnResponseEventType.turn_start.value
class AgentTurnResponseTurnStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_start.value] = (
AgentTurnResponseEventType.turn_start.value
)
turn_id: str
@json_schema_type
class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
AgenticSystemTurnResponseEventType.turn_complete.value
class AgentTurnResponseTurnCompletePayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = (
AgentTurnResponseEventType.turn_complete.value
)
turn: Turn
@json_schema_type
class AgenticSystemTurnResponseEvent(BaseModel):
class AgentTurnResponseEvent(BaseModel):
"""Streamed agent execution response."""
payload: Annotated[
Union[
AgenticSystemTurnResponseStepStartPayload,
AgenticSystemTurnResponseStepProgressPayload,
AgenticSystemTurnResponseStepCompletePayload,
AgenticSystemTurnResponseTurnStartPayload,
AgenticSystemTurnResponseTurnCompletePayload,
AgentTurnResponseStepStartPayload,
AgentTurnResponseStepProgressPayload,
AgentTurnResponseStepCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResponseTurnCompletePayload,
],
Field(discriminator="event_type"),
]
@json_schema_type
class AgenticSystemCreateResponse(BaseModel):
class AgentCreateResponse(BaseModel):
agent_id: str
@json_schema_type
class AgenticSystemSessionCreateResponse(BaseModel):
class AgentSessionCreateResponse(BaseModel):
session_id: str
@json_schema_type
class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn):
class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
agent_id: str
session_id: str
@ -397,24 +389,24 @@ class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn):
@json_schema_type
class AgenticSystemTurnResponseStreamChunk(BaseModel):
event: AgenticSystemTurnResponseEvent
class AgentTurnResponseStreamChunk(BaseModel):
event: AgentTurnResponseEvent
@json_schema_type
class AgenticSystemStepResponse(BaseModel):
class AgentStepResponse(BaseModel):
step: Step
class AgenticSystem(Protocol):
@webmethod(route="/agentic_system/create")
async def create_agentic_system(
class Agents(Protocol):
@webmethod(route="/agents/create")
async def create_agent(
self,
agent_config: AgentConfig,
) -> AgenticSystemCreateResponse: ...
) -> AgentCreateResponse: ...
@webmethod(route="/agentic_system/turn/create")
async def create_agentic_system_turn(
@webmethod(route="/agents/turn/create")
async def create_agent_turn(
self,
agent_id: str,
session_id: str,
@ -426,42 +418,40 @@ class AgenticSystem(Protocol):
],
attachments: Optional[List[Attachment]] = None,
stream: Optional[bool] = False,
) -> AgenticSystemTurnResponseStreamChunk: ...
) -> AgentTurnResponseStreamChunk: ...
@webmethod(route="/agentic_system/turn/get")
async def get_agentic_system_turn(
@webmethod(route="/agents/turn/get")
async def get_agents_turn(
self,
agent_id: str,
turn_id: str,
) -> Turn: ...
@webmethod(route="/agentic_system/step/get")
async def get_agentic_system_step(
@webmethod(route="/agents/step/get")
async def get_agents_step(
self, agent_id: str, turn_id: str, step_id: str
) -> AgenticSystemStepResponse: ...
) -> AgentStepResponse: ...
@webmethod(route="/agentic_system/session/create")
async def create_agentic_system_session(
@webmethod(route="/agents/session/create")
async def create_agent_session(
self,
agent_id: str,
session_name: str,
) -> AgenticSystemSessionCreateResponse: ...
) -> AgentSessionCreateResponse: ...
@webmethod(route="/agentic_system/session/get")
async def get_agentic_system_session(
@webmethod(route="/agents/session/get")
async def get_agents_session(
self,
agent_id: str,
session_id: str,
turn_ids: Optional[List[str]] = None,
) -> Session: ...
@webmethod(route="/agentic_system/session/delete")
async def delete_agentic_system_session(
self, agent_id: str, session_id: str
) -> None: ...
@webmethod(route="/agents/session/delete")
async def delete_agents_session(self, agent_id: str, session_id: str) -> None: ...
@webmethod(route="/agentic_system/delete")
async def delete_agentic_system(
@webmethod(route="/agents/delete")
async def delete_agents(
self,
agent_id: str,
) -> None: ...
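
The renamed streaming payloads above still form a pydantic discriminated union keyed on event_type. A minimal standalone sketch of how that dispatch behaves (simplified: only two payload variants, Turn replaced by a plain dict, and the json_schema_type decorators omitted; this is not the repo's exact code):

```
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field


class AgentTurnResponseTurnStartPayload(BaseModel):
    event_type: Literal["turn_start"] = "turn_start"
    turn_id: str


class AgentTurnResponseTurnCompletePayload(BaseModel):
    event_type: Literal["turn_complete"] = "turn_complete"
    turn: dict  # simplified stand-in for the real Turn type


class AgentTurnResponseEvent(BaseModel):
    payload: Annotated[
        Union[
            AgentTurnResponseTurnStartPayload,
            AgentTurnResponseTurnCompletePayload,
        ],
        Field(discriminator="event_type"),
    ]


# a raw chunk is routed to the right payload class by its event_type value
event = AgentTurnResponseEvent(payload={"event_type": "turn_start", "turn_id": "1234"})
assert isinstance(event.payload, AgentTurnResponseTurnStartPayload)
```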

View file

@ -18,44 +18,42 @@ from termcolor import cprint
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.core.datatypes import RemoteProviderConfig
from .api import * # noqa: F403
from .agents import * # noqa: F403
from .event_logger import EventLogger
async def get_client_impl(config: RemoteProviderConfig, _deps):
return AgenticSystemClient(config.url)
return AgentsClient(config.url)
def encodable_dict(d: BaseModel):
return json.loads(d.json())
class AgenticSystemClient(AgenticSystem):
class AgentsClient(Agents):
def __init__(self, base_url: str):
self.base_url = base_url
async def create_agentic_system(
self, agent_config: AgentConfig
) -> AgenticSystemCreateResponse:
async def create_agent(self, agent_config: AgentConfig) -> AgentCreateResponse:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/agentic_system/create",
f"{self.base_url}/agents/create",
json={
"agent_config": encodable_dict(agent_config),
},
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
return AgenticSystemCreateResponse(**response.json())
return AgentCreateResponse(**response.json())
async def create_agentic_system_session(
async def create_agent_session(
self,
agent_id: str,
session_name: str,
) -> AgenticSystemSessionCreateResponse:
) -> AgentSessionCreateResponse:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/agentic_system/session/create",
f"{self.base_url}/agents/session/create",
json={
"agent_id": agent_id,
"session_name": session_name,
@ -63,16 +61,16 @@ class AgenticSystemClient(AgenticSystem):
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
return AgenticSystemSessionCreateResponse(**response.json())
return AgentSessionCreateResponse(**response.json())
async def create_agentic_system_turn(
async def create_agent_turn(
self,
request: AgenticSystemTurnCreateRequest,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
f"{self.base_url}/agentic_system/turn/create",
f"{self.base_url}/agents/turn/create",
json=encodable_dict(request),
headers={"Content-Type": "application/json"},
timeout=20,
@ -86,7 +84,7 @@ class AgenticSystemClient(AgenticSystem):
cprint(data, "red")
continue
yield AgenticSystemTurnResponseStreamChunk(**jdata)
yield AgentTurnResponseStreamChunk(**jdata)
except Exception as e:
print(data)
print(f"Error with parsing or validation: {e}")
@ -102,16 +100,16 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
tool_prompt_format=ToolPromptFormat.function_tag,
)
create_response = await api.create_agentic_system(agent_config)
session_response = await api.create_agentic_system_session(
create_response = await api.create_agent(agent_config)
session_response = await api.create_agent_session(
agent_id=create_response.agent_id,
session_name="test_session",
)
for content in user_prompts:
cprint(f"User> {content}", color="white", attrs=["bold"])
iterator = api.create_agentic_system_turn(
AgenticSystemTurnCreateRequest(
iterator = api.create_agent_turn(
AgentTurnCreateRequest(
agent_id=create_response.agent_id,
session_id=session_response.session_id,
messages=[
@ -128,7 +126,7 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
async def run_main(host: str, port: int):
api = AgenticSystemClient(f"http://{host}:{port}")
api = AgentsClient(f"http://{host}:{port}")
tool_definitions = [
SearchToolDefinition(engine=SearchEngineType.bing),
@ -165,7 +163,7 @@ async def run_main(host: str, port: int):
async def run_rag(host: str, port: int):
api = AgenticSystemClient(f"http://{host}:{port}")
api = AgentsClient(f"http://{host}:{port}")
urls = [
"memory_optimizations.rst",

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .batch_inference import * # noqa: F401 F403

View file

@ -11,7 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.inference.api import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
@json_schema_type

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .dataset import * # noqa: F401 F403

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .api import * # noqa: F401 F403
from .evals import * # noqa: F401 F403

View file

@ -12,7 +12,7 @@ from llama_models.schema_utils import webmethod
from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.dataset.api import * # noqa: F403
from llama_stack.apis.dataset import * # noqa: F403
from llama_stack.common.training_types import * # noqa: F403

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .inference import * # noqa: F401 F403

View file

@ -10,12 +10,14 @@ from typing import Any, AsyncGenerator
import fire
import httpx
from llama_stack.core.datatypes import RemoteProviderConfig
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.core.datatypes import RemoteProviderConfig
from .event_logger import EventLogger
from .api import (
from .inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
@ -23,7 +25,6 @@ from .api import (
Inference,
UserMessage,
)
from .event_logger import EventLogger
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .memory import * # noqa: F401 F403

View file

@ -13,11 +13,11 @@ from typing import Any, Dict, List, Optional
import fire
import httpx
from termcolor import cprint
from llama_stack.core.datatypes import RemoteProviderConfig
from termcolor import cprint
from .api import * # noqa: F403
from .memory import * # noqa: F403
from .common.file_utils import data_url_from_file

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .models import * # noqa: F401 F403

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .post_training import * # noqa: F401 F403

View file

@ -14,7 +14,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.dataset.api import * # noqa: F403
from llama_stack.apis.dataset import * # noqa: F403
from llama_stack.common.training_types import * # noqa: F403

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .reward_scoring import * # noqa: F401 F403

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .safety import * # noqa: F401 F403

View file

@ -13,12 +13,12 @@ import fire
import httpx
from llama_models.llama3.api.datatypes import UserMessage
from llama_stack.core.datatypes import RemoteProviderConfig
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.core.datatypes import RemoteProviderConfig
from .api import * # noqa: F403
from .safety import * # noqa: F403
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety:

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .synthetic_data_generation import * # noqa: F401 F403

View file

@ -13,7 +13,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.reward_scoring.api import * # noqa: F403
from llama_stack.apis.reward_scoring import * # noqa: F403
class FilteringFunction(Enum):

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .telemetry import * # noqa: F401 F403

View file

@ -31,16 +31,6 @@ class LlamaCLIParser:
ModelParser.create(subparsers)
StackParser.create(subparsers)
# Import sub-commands from agentic_system if they exist
try:
from llama_agentic_system.cli.subcommand_modules import SUBCOMMAND_MODULES
for module in SUBCOMMAND_MODULES:
module.create(subparsers)
except ImportError:
pass
def parse_args(self) -> argparse.Namespace:
return self.parser.parse_args()

View file

@ -5,6 +5,6 @@ distribution_spec:
inference: meta-reference
memory: meta-reference-faiss
safety: meta-reference
agentic_system: meta-reference
agents: meta-reference
telemetry: console
image_type: conda

View file

@ -5,6 +5,6 @@ distribution_spec:
inference: remote::fireworks
memory: meta-reference-faiss
safety: meta-reference
agentic_system: meta-reference
agents: meta-reference
telemetry: console
image_type: conda

View file

@ -5,6 +5,6 @@ distribution_spec:
inference: remote::ollama
memory: meta-reference-faiss
safety: meta-reference
agentic_system: meta-reference
agents: meta-reference
telemetry: console
image_type: conda

View file

@ -5,6 +5,6 @@ distribution_spec:
inference: remote::tgi
memory: meta-reference-faiss
safety: meta-reference
agentic_system: meta-reference
agents: meta-reference
telemetry: console
image_type: conda

View file

@ -5,6 +5,6 @@ distribution_spec:
inference: remote::together
memory: meta-reference-faiss
safety: meta-reference
agentic_system: meta-reference
agents: meta-reference
telemetry: console
image_type: conda

View file

@ -5,6 +5,6 @@ distribution_spec:
inference: meta-reference
memory: meta-reference-faiss
safety: meta-reference
agentic_system: meta-reference
agents: meta-reference
telemetry: console
image_type: docker

View file

@ -17,7 +17,7 @@ from pydantic import BaseModel, Field, validator
class Api(Enum):
inference = "inference"
safety = "safety"
agentic_system = "agentic_system"
agents = "agents"
memory = "memory"
telemetry = "telemetry"

View file

@ -8,11 +8,11 @@ import importlib
import inspect
from typing import Dict, List
from llama_stack.agentic_system.api import AgenticSystem
from llama_stack.inference.api import Inference
from llama_stack.memory.api import Memory
from llama_stack.safety.api import Safety
from llama_stack.telemetry.api import Telemetry
from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety
from llama_stack.apis.telemetry import Telemetry
from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec
@ -34,7 +34,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
protocols = {
Api.inference: Inference,
Api.safety: Safety,
Api.agentic_system: AgenticSystem,
Api.agents: Agents,
Api.memory: Memory,
Api.telemetry: Telemetry,
}
@ -67,7 +67,7 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
ret = {}
for api in stack_apis():
name = api.name.lower()
module = importlib.import_module(f"llama_stack.{name}.providers")
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
ret[api] = {
"remote": remote_provider_spec(api),
**{a.provider_id: a for a in module.available_providers()},
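
In other words, each API now resolves its provider registry from llama_stack.providers.registry.<api> instead of llama_stack.<api>.providers. A hedged sketch of that lookup (load_registry is an illustrative helper, not a function in the repo):

```
import importlib


def load_registry(api_name: str):
    # e.g. "agents" -> llama_stack.providers.registry.agents
    module = importlib.import_module(f"llama_stack.providers.registry.{api_name}")
    # each registry module exposes available_providers(), as shown further below
    return {spec.provider_id: spec for spec in module.available_providers()}
```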

View file

@ -39,7 +39,7 @@ from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.telemetry.tracing import (
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
setup_logger,
SpanStatus,

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .api import * # noqa: F401 F403

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .api import * # noqa: F401 F403

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .api import * # noqa: F401 F403

View file

@ -6,15 +6,16 @@
from typing import AsyncGenerator
from fireworks.client import Fireworks
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from llama_stack.inference.api import * # noqa: F403
from llama_stack.inference.prepare_messages import prepare_messages
from fireworks.client import Fireworks
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
from .config import FireworksImplConfig
@ -81,7 +82,7 @@ class FireworksInferenceAdapter(Inference):
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = list(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
@ -91,7 +92,7 @@ class FireworksInferenceAdapter(Inference):
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
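
Beyond the import move, this hunk (and the matching ones in the Ollama, TGI, Together, and meta-reference inference adapters below) replaces the tools: Optional[List[ToolDefinition]] = list() default with None plus tools or [] at the call site. A standalone sketch of the pitfall that avoids; the helper names here are illustrative, not repo code:

```
from typing import List, Optional


def risky(items: Optional[List[str]] = list()) -> List[str]:
    items.append("x")          # mutates the single default list shared by all calls
    return items


def safe(items: Optional[List[str]] = None) -> List[str]:
    items = items or []        # fresh list per call, mirroring `tools or []`
    items.append("x")
    return items


assert risky() == ["x"]
assert risky() == ["x", "x"]   # the shared default keeps growing across calls
assert safe() == ["x"]
assert safe() == ["x"]         # independent list each call
```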

View file

@ -12,10 +12,11 @@ from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from ollama import AsyncClient
from llama_stack.inference.api import * # noqa: F403
from llama_stack.inference.prepare_messages import prepare_messages
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
# TODO: Eventually this will move to the llama cli model list command
# mapping of Model SKUs to ollama models
@ -89,7 +90,7 @@ class OllamaInferenceAdapter(Inference):
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = list(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
@ -99,7 +100,7 @@ class OllamaInferenceAdapter(Inference):
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,

View file

@ -13,8 +13,8 @@ from huggingface_hub import HfApi, InferenceClient
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.inference.api import * # noqa: F403
from llama_stack.inference.prepare_messages import prepare_messages
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
from .config import TGIImplConfig
@ -87,7 +87,7 @@ class TGIAdapter(Inference):
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = list(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
@ -97,7 +97,7 @@ class TGIAdapter(Inference):
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,

View file

@ -11,10 +11,11 @@ from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from together import Together
from llama_stack.inference.api import * # noqa: F403
from llama_stack.inference.prepare_messages import prepare_messages
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
from .config import TogetherImplConfig
@ -81,7 +82,7 @@ class TogetherInferenceAdapter(Inference):
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = list(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
@ -92,7 +93,7 @@ class TogetherInferenceAdapter(Inference):
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,

View file

@ -12,10 +12,13 @@ from urllib.parse import urlparse
import chromadb
from numpy.typing import NDArray
from llama_stack.memory.api import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.memory.common.vector_store import BankWithIndex, EmbeddingIndex
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
)
class ChromaIndex(EmbeddingIndex):

View file

@ -13,10 +13,10 @@ from numpy.typing import NDArray
from psycopg2 import sql
from psycopg2.extras import execute_values, Json
from pydantic import BaseModel
from llama_stack.memory.api import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.memory.common.vector_store import (
from llama_stack.providers.utils.memory.vector_store import (
ALL_MINILM_L6_V2_DIMENSION,
BankWithIndex,
EmbeddingIndex,

View file

@ -14,13 +14,13 @@ from .config import MetaReferenceImplConfig
async def get_provider_impl(
config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
):
from .agentic_system import MetaReferenceAgenticSystemImpl
from .agents import MetaReferenceAgentsImpl
assert isinstance(
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"
impl = MetaReferenceAgenticSystemImpl(
impl = MetaReferenceAgentsImpl(
config,
deps[Api.inference],
deps[Api.memory],

View file

@ -20,10 +20,10 @@ import httpx
from termcolor import cprint
from llama_stack.agentic_system.api import * # noqa: F403
from llama_stack.inference.api import * # noqa: F403
from llama_stack.memory.api import * # noqa: F403
from llama_stack.safety.api import * # noqa: F403
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.tools.base import BaseTool
from llama_stack.tools.builtin import (
@ -122,7 +122,7 @@ class ChatAgent(ShieldRunnerMixin):
return session
async def create_and_execute_turn(
self, request: AgenticSystemTurnCreateRequest
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
assert (
request.session_id in self.sessions
@ -141,9 +141,9 @@ class ChatAgent(ShieldRunnerMixin):
turn_id = str(uuid.uuid4())
start_time = datetime.now()
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseTurnStartPayload(
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnStartPayload(
turn_id=turn_id,
)
)
@ -169,12 +169,12 @@ class ChatAgent(ShieldRunnerMixin):
continue
assert isinstance(
chunk, AgenticSystemTurnResponseStreamChunk
chunk, AgentTurnResponseStreamChunk
), f"Unexpected type {type(chunk)}"
event = chunk.event
if (
event.payload.event_type
== AgenticSystemTurnResponseEventType.step_complete.value
== AgentTurnResponseEventType.step_complete.value
):
steps.append(event.payload.step_details)
@ -193,9 +193,9 @@ class ChatAgent(ShieldRunnerMixin):
)
session.turns.append(turn)
chunk = AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseTurnCompletePayload(
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
@ -261,9 +261,9 @@ class ChatAgent(ShieldRunnerMixin):
step_id = str(uuid.uuid4())
try:
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepStartPayload(
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.shield_call.value,
step_id=step_id,
metadata=dict(touchpoint=touchpoint),
@ -273,9 +273,9 @@ class ChatAgent(ShieldRunnerMixin):
await self.run_shields(messages, shields)
except SafetyException as e:
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepCompletePayload(
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_details=ShieldCallStep(
step_id=step_id,
@ -292,9 +292,9 @@ class ChatAgent(ShieldRunnerMixin):
)
yield False
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepCompletePayload(
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_details=ShieldCallStep(
step_id=step_id,
@ -325,9 +325,9 @@ class ChatAgent(ShieldRunnerMixin):
)
if need_rag_context:
step_id = str(uuid.uuid4())
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepStartPayload(
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.memory_retrieval.value,
step_id=step_id,
)
@ -341,9 +341,9 @@ class ChatAgent(ShieldRunnerMixin):
)
step_id = str(uuid.uuid4())
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepCompletePayload(
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.memory_retrieval.value,
step_id=step_id,
step_details=MemoryRetrievalStep(
@ -360,7 +360,7 @@ class ChatAgent(ShieldRunnerMixin):
last_message = input_messages[-1]
last_message.context = "\n".join(rag_context)
elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools:
elif attachments and AgentTool.code_interpreter.value in enabled_tools:
urls = [a.content for a in attachments if isinstance(a.content, URL)]
msg = await attachment_message(self.tempdir, urls)
input_messages.append(msg)
@ -379,9 +379,9 @@ class ChatAgent(ShieldRunnerMixin):
cprint(f"{str(msg)}", color=color)
step_id = str(uuid.uuid4())
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepStartPayload(
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.inference.value,
step_id=step_id,
)
@ -412,9 +412,9 @@ class ChatAgent(ShieldRunnerMixin):
tool_calls.append(delta.content)
if stream:
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepProgressPayload(
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta="",
@ -426,9 +426,9 @@ class ChatAgent(ShieldRunnerMixin):
elif isinstance(delta, str):
content += delta
if stream and event.stop_reason is None:
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepProgressPayload(
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta=event.delta,
@ -448,9 +448,9 @@ class ChatAgent(ShieldRunnerMixin):
tool_calls=tool_calls,
)
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepCompletePayload(
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.inference.value,
step_id=step_id,
step_details=InferenceStep(
@ -498,17 +498,17 @@ class ChatAgent(ShieldRunnerMixin):
return
step_id = str(uuid.uuid4())
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepStartPayload(
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
)
)
)
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepProgressPayload(
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
tool_call=tool_call,
@ -525,9 +525,9 @@ class ChatAgent(ShieldRunnerMixin):
), "Currently not supporting multiple messages"
result_message = result_messages[0]
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepCompletePayload(
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_details=ToolExecutionStep(
step_id=step_id,
@ -547,9 +547,9 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepCompletePayload(
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_details=ShieldCallStep(
step_id=str(uuid.uuid4()),
@ -566,9 +566,9 @@ class ChatAgent(ShieldRunnerMixin):
)
except SafetyException as e:
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepCompletePayload(
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_details=ShieldCallStep(
step_id=str(uuid.uuid4()),
@ -616,18 +616,18 @@ class ChatAgent(ShieldRunnerMixin):
enabled_tools = set(t.type for t in self.agent_config.tools)
if attachments:
if (
AgenticSystemTool.code_interpreter.value in enabled_tools
AgentTool.code_interpreter.value in enabled_tools
and self.agent_config.tool_choice == ToolChoice.required
):
return False
else:
return True
return AgenticSystemTool.memory.value in enabled_tools
return AgentTool.memory.value in enabled_tools
def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
for t in self.agent_config.tools:
if t.type == AgenticSystemTool.memory.value:
if t.type == AgentTool.memory.value:
return t
return None

View file

@ -10,10 +10,10 @@ import tempfile
import uuid
from typing import AsyncGenerator
from llama_stack.inference.api import Inference
from llama_stack.memory.api import Memory
from llama_stack.safety.api import Safety
from llama_stack.agentic_system.api import * # noqa: F403
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.tools.builtin import (
CodeInterpreterTool,
PhotogenTool,
@ -33,7 +33,7 @@ logger.setLevel(logging.INFO)
AGENT_INSTANCES_BY_ID = {}
class MetaReferenceAgenticSystemImpl(AgenticSystem):
class MetaReferenceAgentsImpl(Agents):
def __init__(
self,
config: MetaReferenceImplConfig,
@ -49,10 +49,10 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
async def initialize(self) -> None:
pass
async def create_agentic_system(
async def create_agent(
self,
agent_config: AgentConfig,
) -> AgenticSystemCreateResponse:
) -> AgentCreateResponse:
agent_id = str(uuid.uuid4())
builtin_tools = []
@ -95,24 +95,24 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
builtin_tools=builtin_tools,
)
return AgenticSystemCreateResponse(
return AgentCreateResponse(
agent_id=agent_id,
)
async def create_agentic_system_session(
async def create_agent_session(
self,
agent_id: str,
session_name: str,
) -> AgenticSystemSessionCreateResponse:
) -> AgentSessionCreateResponse:
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
agent = AGENT_INSTANCES_BY_ID[agent_id]
session = agent.create_session(session_name)
return AgenticSystemSessionCreateResponse(
return AgentSessionCreateResponse(
session_id=session.session_id,
)
async def create_agentic_system_turn(
async def create_agent_turn(
self,
agent_id: str,
session_id: str,
@ -126,7 +126,7 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
stream: Optional[bool] = False,
) -> AsyncGenerator:
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = AgenticSystemTurnCreateRequest(
request = AgentTurnCreateRequest(
agent_id=agent_id,
session_id=session_id,
messages=messages,

View file

@ -10,14 +10,14 @@ from jinja2 import Template
from llama_models.llama3.api import * # noqa: F403
from llama_stack.agentic_system.api import (
from llama_stack.apis.agents import (
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
MemoryQueryGenerator,
MemoryQueryGeneratorConfig,
)
from termcolor import cprint # noqa: F401
from llama_stack.inference.api import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
async def generate_rag_query(

View file

@ -7,15 +7,15 @@
from typing import List
from llama_models.llama3.api.datatypes import Message, Role, UserMessage
from termcolor import cprint
from llama_stack.safety.api import (
from llama_stack.apis.safety import (
OnViolationAction,
RunShieldRequest,
Safety,
ShieldDefinition,
ShieldResponse,
)
from termcolor import cprint
class SafetyException(Exception): # noqa: N818

View file

@ -11,9 +11,9 @@ from llama_models.datatypes import ModelFamily
from llama_models.schema_utils import json_schema_type
from llama_models.sku_list import all_registered_models, resolve_model
from pydantic import BaseModel, Field, field_validator
from llama_stack.apis.inference import QuantizationConfig
from llama_stack.inference.api import QuantizationConfig
from pydantic import BaseModel, Field, field_validator
@json_schema_type

View file

@ -28,10 +28,10 @@ from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.sku_list import resolve_model
from termcolor import cprint
from llama_stack.apis.inference import QuantizationType
from llama_stack.common.model_utils import model_local_dir
from llama_stack.inference.api import QuantizationType
from termcolor import cprint
from .config import MetaReferenceImplConfig

View file

@ -11,7 +11,7 @@ from typing import AsyncIterator, Union
from llama_models.llama3.api.datatypes import StopReason
from llama_models.sku_list import resolve_model
from llama_stack.inference.api import (
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
@ -21,13 +21,13 @@ from llama_stack.inference.api import (
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.inference.prepare_messages import prepare_messages
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
from .config import MetaReferenceImplConfig
from .model_parallel import LlamaModelParallelGenerator
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.inference.api import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
# there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process.
@ -57,7 +57,7 @@ class MetaReferenceInferenceImpl(Inference):
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = list(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
@ -70,7 +70,7 @@ class MetaReferenceInferenceImpl(Inference):
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,

View file

@ -14,9 +14,9 @@ import torch
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.llama3.api.model import Transformer, TransformerBlock
from llama_stack.inference.api import QuantizationType
from llama_stack.apis.inference import QuantizationType
from llama_stack.inference.api.config import (
from llama_stack.apis.inference.config import (
CheckpointQuantizationFormat,
MetaReferenceImplConfig,
)

View file

@ -15,13 +15,14 @@ from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.memory.api import * # noqa: F403
from llama_stack.memory.common.vector_store import (
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.utils.memory.vector_store import (
ALL_MINILM_L6_V2_DIMENSION,
BankWithIndex,
EmbeddingIndex,
)
from llama_stack.telemetry import tracing
from llama_stack.providers.utils.telemetry import tracing
from .config import FaissImplConfig
logger = logging.getLogger(__name__)

View file

@ -9,7 +9,7 @@ import asyncio
from llama_models.sku_list import resolve_model
from llama_stack.common.model_utils import model_local_dir
from llama_stack.safety.api import * # noqa
from llama_stack.apis.safety import * # noqa
from .config import SafetyConfig
from .shields import (

View file

@ -8,7 +8,7 @@ from abc import ABC, abstractmethod
from typing import List
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
from llama_stack.safety.api import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"

View file

@ -8,7 +8,7 @@ from codeshield.cs import CodeShield
from termcolor import cprint
from .base import ShieldResponse, TextShield
from llama_stack.safety.api import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
class CodeScannerShield(TextShield):

View file

@ -14,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, Role
from transformers import AutoModelForCausalLM, AutoTokenizer
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
from llama_stack.safety.api import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
SAFE_RESPONSE = "safe"
_INSTANCE = None

View file

@ -14,7 +14,7 @@ from termcolor import cprint
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
from llama_stack.safety.api import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
class PromptGuardShield(TextShield):

View file

@ -6,7 +6,7 @@
from typing import Optional
from llama_stack.telemetry.api import * # noqa: F403
from llama_stack.apis.telemetry import * # noqa: F403
from .config import ConsoleConfig

View file

@ -12,7 +12,7 @@ from llama_stack.core.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.agentic_system,
api=Api.agents,
provider_id="meta-reference",
pip_packages=[
"codeshield",
@ -23,8 +23,8 @@ def available_providers() -> List[ProviderSpec]:
"torch",
"transformers",
],
module="llama_stack.agentic_system.meta_reference",
config_class="llama_stack.agentic_system.meta_reference.MetaReferenceImplConfig",
module="llama_stack.providers.impls.meta_reference.agents",
config_class="llama_stack.providers.impls.meta_reference.agents.MetaReferenceImplConfig",
api_dependencies=[
Api.inference,
Api.safety,

Some files were not shown because too many files have changed in this diff.