Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)
Make each inference provider into its own subdirectory
commit 0de5a807c7, parent f64668319c
42 changed files with 123 additions and 103 deletions
@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .agentic_system import get_provider_impl  # noqa
+from .config import AgenticSystemConfig  # noqa
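Note: each provider now lives in its own package whose __init__.py re-exports a get_provider_impl entry point plus its config class (here AgenticSystemConfig), so the distribution layer can locate an implementation purely from the module and config_class strings in its ProviderSpec. A minimal, hypothetical sketch of that lookup (the real entry-point signature is not shown in this diff and is assumed here):

# Hypothetical sketch only -- not code from this commit.
# Assumes the provider package re-exports get_provider_impl from its __init__.py,
# as the new __init__.py files in this diff do, and that it accepts the parsed config.
import importlib

def load_provider(module_path: str, config):
    # e.g. module_path = "llama_toolchain.agentic_system.meta_reference"
    provider_module = importlib.import_module(module_path)
    return provider_module.get_provider_impl(config)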
@@ -10,12 +10,24 @@ import uuid
 from datetime import datetime
 from typing import AsyncGenerator, List, Optional
 
-from llama_toolchain.inference.api import Inference
-from llama_toolchain.safety.api import Safety
+from llama_toolchain.agentic_system.api.datatypes import (
+    AgenticSystemInstanceConfig,
+    AgenticSystemTurnResponseEvent,
+    AgenticSystemTurnResponseEventType,
+    AgenticSystemTurnResponseStepCompletePayload,
+    AgenticSystemTurnResponseStepProgressPayload,
+    AgenticSystemTurnResponseStepStartPayload,
+    AgenticSystemTurnResponseTurnCompletePayload,
+    AgenticSystemTurnResponseTurnStartPayload,
+    InferenceStep,
+    Session,
+    ShieldCallStep,
+    StepType,
+    ToolExecutionStep,
+    Turn,
+)
 
-from .api.endpoints import *  # noqa
+from llama_toolchain.inference.api import ChatCompletionRequest, Inference
 
-from llama_toolchain.inference.api import ChatCompletionRequest
-
 from llama_toolchain.inference.api.datatypes import (
     Attachment,
@@ -33,36 +45,16 @@ from llama_toolchain.inference.api.datatypes import (
     ToolResponseMessage,
     URL,
 )
+from llama_toolchain.safety.api import Safety
 from llama_toolchain.safety.api.datatypes import (
     BuiltinShield,
     ShieldDefinition,
     ShieldResponse,
 )
 
 from termcolor import cprint
+from llama_toolchain.agentic_system.api.endpoints import *  # noqa
 
-from .api.datatypes import (
-    AgenticSystemInstanceConfig,
-    AgenticSystemTurnResponseEvent,
-    AgenticSystemTurnResponseEventType,
-    AgenticSystemTurnResponseStepCompletePayload,
-    AgenticSystemTurnResponseStepProgressPayload,
-    AgenticSystemTurnResponseStepStartPayload,
-    AgenticSystemTurnResponseTurnCompletePayload,
-    AgenticSystemTurnResponseTurnStartPayload,
-    InferenceStep,
-    Session,
-    ShieldCallStep,
-    StepType,
-    ToolExecutionStep,
-    Turn,
-)
-from .api.endpoints import (
-    AgenticSystemTurnCreateRequest,
-    AgenticSystemTurnResponseStreamChunk,
-)
 from .safety import SafetyException, ShieldRunnerMixin
 
 from .system_prompt import get_agentic_prefix_messages
 from .tools.base import BaseTool
 from .tools.builtin import SingleMessageBuiltinTool
@@ -5,25 +5,18 @@
 # the root directory of this source tree.
 
-
-from llama_toolchain.agentic_system.api import AgenticSystem
-
-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
-from llama_toolchain.inference.api import Inference
-from llama_toolchain.safety.api import Safety
-
-from .config import AgenticSystemConfig
-from .api.endpoints import *  # noqa
-
 import logging
 import os
 import uuid
 from typing import AsyncGenerator, Dict
 
+from llama_toolchain.distribution.datatypes import Api, ProviderSpec
+from llama_toolchain.inference.api import Inference
 from llama_toolchain.inference.api.datatypes import BuiltinTool
-from .agent_instance import AgentInstance
-from .api.endpoints import (
+from llama_toolchain.safety.api import Safety
+from llama_toolchain.agentic_system.api.endpoints import *  # noqa
+from llama_toolchain.agentic_system.api import (
+    AgenticSystem,
     AgenticSystemCreateRequest,
     AgenticSystemCreateResponse,
     AgenticSystemSessionCreateRequest,
@@ -31,6 +24,10 @@ from .api.endpoints import (
     AgenticSystemTurnCreateRequest,
 )
 
+from .agent_instance import AgentInstance
+
+from .config import AgenticSystemConfig
+
 from .tools.builtin import (
     BraveSearchTool,
     CodeInterpreterTool,
@@ -6,7 +6,7 @@
 
 from typing import List
 
-from llama_toolchain.agentic_system.safety import ShieldRunnerMixin
+from llama_toolchain.agentic_system.meta_reference.safety import ShieldRunnerMixin
 
 from llama_toolchain.inference.api import Message
 from llama_toolchain.safety.api.datatypes import ShieldDefinition
@@ -19,8 +19,8 @@ def available_agentic_system_providers() -> List[ProviderSpec]:
             "torch",
             "transformers",
         ],
-        module="llama_toolchain.agentic_system.agentic_system",
-        config_class="llama_toolchain.agentic_system.config.AgenticSystemConfig",
+        module="llama_toolchain.agentic_system.meta_reference",
+        config_class="llama_toolchain.agentic_system.meta_reference.AgenticSystemConfig",
         api_dependencies=[
            Api.inference,
            Api.safety,
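For orientation, the registry entry edited above plausibly reads as follows after this change. The module, config_class, api_dependencies and package names are taken from the hunk; the field name pip_packages, the value Api.agentic_system, and the import location of InlineProviderSpec are assumptions:

# Illustrative reconstruction, not an exact copy of the file.
from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec  # import path assumed

agentic_system_provider = InlineProviderSpec(
    api=Api.agentic_system,  # enum member assumed
    pip_packages=[           # field name assumed
        "torch",
        "transformers",
    ],
    module="llama_toolchain.agentic_system.meta_reference",
    config_class="llama_toolchain.agentic_system.meta_reference.AgenticSystemConfig",
    api_dependencies=[
        Api.inference,
        Api.safety,
    ],
)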
llama_toolchain/agentic_system/tools/custom/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
@@ -12,7 +12,10 @@ from typing import Dict, List
 from llama_models.llama3_1.api.datatypes import *  # noqa: F403
 from llama_toolchain.agentic_system.api import *  # noqa: F403
 
-from .builtin import interpret_content_as_attachment
+# TODO: this is symptomatic of us needing to pull more tooling related utilities
+from llama_toolchain.agentic_system.meta_reference.tools.builtin import (
+    interpret_content_as_attachment,
+)
 
 
 class CustomTool:
@@ -17,10 +17,15 @@ from llama_toolchain.agentic_system.api import (
 )
 from llama_toolchain.agentic_system.client import AgenticSystemClient
 
-from llama_toolchain.agentic_system.tools.execute import execute_with_custom_tools
+from llama_toolchain.agentic_system.tools.custom.execute import (
+    execute_with_custom_tools,
+)
 from llama_toolchain.safety.api.datatypes import BuiltinShield, ShieldDefinition
 
 
+# TODO: this should move back to the llama-agentic-system repo
+
+
 class AgenticSystemClientWrapper:
 
     def __init__(self, api, system_id, custom_tools):
@@ -1,22 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-# from .api.config import ImplType, InferenceConfig
-
-
-# async def get_inference_api_instance(config: InferenceConfig):
-#     if config.impl_config.impl_type == ImplType.inline.value:
-#         from .inference import InferenceImpl
-
-#         return InferenceImpl(config.impl_config)
-#     elif config.impl_config.impl_type == ImplType.ollama.value:
-#         from .ollama import OllamaInference
-
-#         return OllamaInference(config.impl_config)
-
-#     from .client import InferenceClient
-
-#         return InferenceClient(config.impl_config.url)
llama_toolchain/inference/meta_reference/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import MetaReferenceImplConfig  # noqa
+from .inference import get_provider_impl  # noqa
@@ -13,7 +13,7 @@ from pydantic import BaseModel, Field
 from strong_typing.schema import json_schema_type
 from typing_extensions import Annotated
 
-from .datatypes import QuantizationConfig
+from llama_toolchain.inference.api import QuantizationConfig
 
 
 @json_schema_type
@@ -63,9 +63,3 @@ class MetaReferenceImplConfig(BaseModel):
     torch_seed: Optional[int] = None
     max_seq_len: int
     max_batch_size: int = 1
-
-
-@json_schema_type
-class OllamaImplConfig(BaseModel):
-    model: str = Field(..., description="The name of the model in ollama catalog")
-    url: str = Field(..., description="The URL for the ollama server")
@@ -29,8 +29,9 @@ from llama_models.llama3_1.api.model import Transformer
 from llama_models.llama3_1.api.tokenizer import Tokenizer
 from termcolor import cprint
 
-from .api.config import CheckpointType, MetaReferenceImplConfig
-from .api.datatypes import QuantizationType
+from llama_toolchain.inference.api import QuantizationType
+
+from .config import CheckpointType, MetaReferenceImplConfig
 
 
 @dataclass
@@ -12,20 +12,18 @@ from llama_models.llama3_1.api.datatypes import StopReason
 from llama_models.sku_list import resolve_model
 
 from llama_toolchain.distribution.datatypes import Api, ProviderSpec
-
-from .api.config import MetaReferenceImplConfig
-from .api.datatypes import (
+from llama_toolchain.inference.api import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
+    Inference,
     ToolCallDelta,
     ToolCallParseStatus,
 )
-from .api.endpoints import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseStreamChunk,
-    Inference,
-)
+
+from .config import MetaReferenceImplConfig
 from .model_parallel import LlamaModelParallelGenerator
 
 
@@ -13,7 +13,7 @@ from llama_models.llama3_1.api.chat_format import ChatFormat
 from llama_models.llama3_1.api.datatypes import Message
 from llama_models.llama3_1.api.tokenizer import Tokenizer
 
-from .api.config import MetaReferenceImplConfig
+from .config import MetaReferenceImplConfig
 from .generation import Llama
 from .parallel_utils import ModelParallelProcessGroup
 
llama_toolchain/inference/ollama/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import OllamaImplConfig  # noqa
+from .ollama import get_provider_impl  # noqa
llama_toolchain/inference/ollama/config.py (new file, 14 lines)
@@ -0,0 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel, Field
+from strong_typing.schema import json_schema_type
+
+
+@json_schema_type
+class OllamaImplConfig(BaseModel):
+    model: str = Field(..., description="The name of the model in ollama catalog")
+    url: str = Field(..., description="The URL for the ollama server")
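A quick usage illustration of the new OllamaImplConfig; the import works because the package __init__.py above re-exports it. The model tag and URL below are placeholder values, not values from this commit:

# Illustrative only; placeholder values.
from llama_toolchain.inference.ollama import OllamaImplConfig

config = OllamaImplConfig(
    model="llama3.1:8b-instruct-fp16",  # placeholder Ollama model tag
    url="http://localhost:11434",       # placeholder Ollama server URL
)
print(config.model, config.url)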
@@ -22,21 +22,20 @@ from llama_models.sku_list import resolve_model
 
 from ollama import AsyncClient
 
-from .api.config import OllamaImplConfig
-from .api.datatypes import (
-    ChatCompletionResponseEvent,
-    ChatCompletionResponseEventType,
-    ToolCallDelta,
-    ToolCallParseStatus,
-)
-from .api.endpoints import (
+from llama_toolchain.inference.api import (
     ChatCompletionRequest,
     ChatCompletionResponse,
+    ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
     Inference,
+    ToolCallDelta,
+    ToolCallParseStatus,
 )
 
+from .config import OllamaImplConfig
 
 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models
 OLLAMA_SUPPORTED_SKUS = {
@@ -18,8 +18,8 @@ def available_inference_providers() -> List[ProviderSpec]:
             "torch",
             "zmq",
         ],
-        module="llama_toolchain.inference.inference",
-        config_class="llama_toolchain.inference.inference.MetaReferenceImplConfig",
+        module="llama_toolchain.inference.meta_reference",
+        config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig",
     ),
     InlineProviderSpec(
         api=Api.inference,
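The config_class values are plain dotted paths, so they can be resolved back to the class at runtime. A hypothetical sketch of that resolution (the helper actually used by the distribution code is not part of this diff):

# Hypothetical sketch only -- not code from this commit.
import importlib

def resolve_config_class(dotted_path: str):
    # "llama_toolchain.inference.meta_reference.MetaReferenceImplConfig"
    # -> module "llama_toolchain.inference.meta_reference",
    #    attribute "MetaReferenceImplConfig"
    module_path, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)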
llama_toolchain/safety/meta_reference/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import SafetyConfig  # noqa
+from .safety import get_provider_impl  # noqa
@@ -11,7 +11,7 @@ from typing import Dict
 from llama_toolchain.distribution.datatypes import Api, ProviderSpec
 
 from .config import SafetyConfig
-from .api.endpoints import *  # noqa
+from llama_toolchain.safety.api import *  # noqa
 from .shields import (
     CodeScannerShield,
     InjectionShield,
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
@@ -4,14 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import sys
 from typing import List
 
 from llama_models.llama3_1.api.datatypes import Message
 
-parent_dir = "../.."
-sys.path.append(parent_dir)
-from llama_toolchain.safety.shields.base import (
+from llama_toolchain.safety.meta_reference.shields.base import (
     OnViolationAction,
     ShieldBase,
     ShieldResponse,
@@ -19,7 +19,7 @@ def available_safety_providers() -> List[ProviderSpec]:
             "torch",
             "transformers",
         ],
-        module="llama_toolchain.safety.safety",
-        config_class="llama_toolchain.safety.config.SafetyConfig",
+        module="llama_toolchain.safety.meta_reference",
+        config_class="llama_toolchain.safety.meta_reference.SafetyConfig",
     ),
 ]