From 0de5a807c7d76f13ced14ddc3c6fa7e735ffa0d4 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 5 Aug 2024 15:13:52 -0700 Subject: [PATCH] Make each inference provider into its own subdirectory --- .../agentic_system/meta_reference/__init__.py | 8 ++++ .../{ => meta_reference}/agent_instance.py | 46 ++++++++----------- .../{ => meta_reference}/agentic_system.py | 23 ++++------ .../{ => meta_reference}/config.py | 0 .../{ => meta_reference}/safety.py | 0 .../{ => meta_reference}/system_prompt.py | 0 .../tools}/__init__.py | 0 .../{ => meta_reference}/tools/base.py | 0 .../{ => meta_reference}/tools/builtin.py | 0 .../tools/ipython_tool}/__init__.py | 0 .../tools/ipython_tool/code_env_prefix.py | 0 .../tools/ipython_tool/code_execution.py | 0 .../ipython_tool/matplotlib_custom_backend.py | 0 .../tools/ipython_tool/utils.py | 0 .../{ => meta_reference}/tools/safety.py | 2 +- llama_toolchain/agentic_system/providers.py | 4 +- .../agentic_system/tools/custom/__init__.py | 5 ++ .../tools/{custom.py => custom/datatypes.py} | 5 +- .../tools/{ => custom}/execute.py | 0 llama_toolchain/agentic_system/utils.py | 7 ++- llama_toolchain/inference/api_instance.py | 22 --------- .../inference/meta_reference/__init__.py | 8 ++++ .../{api => meta_reference}/config.py | 8 +--- .../{ => meta_reference}/generation.py | 5 +- .../{ => meta_reference}/inference.py | 16 +++---- .../{ => meta_reference}/model_parallel.py | 2 +- .../{ => meta_reference}/parallel_utils.py | 0 llama_toolchain/inference/ollama/__init__.py | 8 ++++ llama_toolchain/inference/ollama/config.py | 14 ++++++ .../inference/{ => ollama}/ollama.py | 15 +++--- llama_toolchain/inference/providers.py | 4 +- .../safety/meta_reference/__init__.py | 8 ++++ .../safety/{ => meta_reference}/config.py | 0 .../safety/{ => meta_reference}/safety.py | 2 +- .../{ => meta_reference}/shields/__init__.py | 0 .../{ => meta_reference}/shields/base.py | 0 .../shields/code_scanner.py | 0 .../shields/contrib/__init__.py | 5 ++ .../shields/contrib/third_party_shield.py | 5 +- .../shields/llama_guard.py | 0 .../shields/prompt_guard.py | 0 llama_toolchain/safety/providers.py | 4 +- 42 files changed, 123 insertions(+), 103 deletions(-) create mode 100644 llama_toolchain/agentic_system/meta_reference/__init__.py rename llama_toolchain/agentic_system/{ => meta_reference}/agent_instance.py (98%) rename llama_toolchain/agentic_system/{ => meta_reference}/agentic_system.py (97%) rename llama_toolchain/agentic_system/{ => meta_reference}/config.py (100%) rename llama_toolchain/agentic_system/{ => meta_reference}/safety.py (100%) rename llama_toolchain/agentic_system/{ => meta_reference}/system_prompt.py (100%) rename llama_toolchain/agentic_system/{tools/ipython_tool => meta_reference/tools}/__init__.py (100%) rename llama_toolchain/agentic_system/{ => meta_reference}/tools/base.py (100%) rename llama_toolchain/agentic_system/{ => meta_reference}/tools/builtin.py (100%) rename llama_toolchain/{safety/shields/contrib => agentic_system/meta_reference/tools/ipython_tool}/__init__.py (100%) rename llama_toolchain/agentic_system/{ => meta_reference}/tools/ipython_tool/code_env_prefix.py (100%) rename llama_toolchain/agentic_system/{ => meta_reference}/tools/ipython_tool/code_execution.py (100%) rename llama_toolchain/agentic_system/{ => meta_reference}/tools/ipython_tool/matplotlib_custom_backend.py (100%) rename llama_toolchain/agentic_system/{ => meta_reference}/tools/ipython_tool/utils.py (100%) rename llama_toolchain/agentic_system/{ => meta_reference}/tools/safety.py (95%) create mode 100644 llama_toolchain/agentic_system/tools/custom/__init__.py rename llama_toolchain/agentic_system/tools/{custom.py => custom/datatypes.py} (94%) rename llama_toolchain/agentic_system/tools/{ => custom}/execute.py (100%) delete mode 100644 llama_toolchain/inference/api_instance.py create mode 100644 llama_toolchain/inference/meta_reference/__init__.py rename llama_toolchain/inference/{api => meta_reference}/config.py (87%) rename llama_toolchain/inference/{ => meta_reference}/generation.py (98%) rename llama_toolchain/inference/{ => meta_reference}/inference.py (98%) rename llama_toolchain/inference/{ => meta_reference}/model_parallel.py (98%) rename llama_toolchain/inference/{ => meta_reference}/parallel_utils.py (100%) create mode 100644 llama_toolchain/inference/ollama/__init__.py create mode 100644 llama_toolchain/inference/ollama/config.py rename llama_toolchain/inference/{ => ollama}/ollama.py (99%) create mode 100644 llama_toolchain/safety/meta_reference/__init__.py rename llama_toolchain/safety/{ => meta_reference}/config.py (100%) rename llama_toolchain/safety/{ => meta_reference}/safety.py (98%) rename llama_toolchain/safety/{ => meta_reference}/shields/__init__.py (100%) rename llama_toolchain/safety/{ => meta_reference}/shields/base.py (100%) rename llama_toolchain/safety/{ => meta_reference}/shields/code_scanner.py (100%) create mode 100644 llama_toolchain/safety/meta_reference/shields/contrib/__init__.py rename llama_toolchain/safety/{ => meta_reference}/shields/contrib/third_party_shield.py (89%) rename llama_toolchain/safety/{ => meta_reference}/shields/llama_guard.py (100%) rename llama_toolchain/safety/{ => meta_reference}/shields/prompt_guard.py (100%) diff --git a/llama_toolchain/agentic_system/meta_reference/__init__.py b/llama_toolchain/agentic_system/meta_reference/__init__.py new file mode 100644 index 000000000..22b1f788a --- /dev/null +++ b/llama_toolchain/agentic_system/meta_reference/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .agentic_system import get_provider_impl # noqa +from .config import AgenticSystemConfig # noqa diff --git a/llama_toolchain/agentic_system/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py similarity index 98% rename from llama_toolchain/agentic_system/agent_instance.py rename to llama_toolchain/agentic_system/meta_reference/agent_instance.py index afb00655e..06a5bb3db 100644 --- a/llama_toolchain/agentic_system/agent_instance.py +++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py @@ -10,12 +10,24 @@ import uuid from datetime import datetime from typing import AsyncGenerator, List, Optional -from llama_toolchain.inference.api import Inference -from llama_toolchain.safety.api import Safety +from llama_toolchain.agentic_system.api.datatypes import ( + AgenticSystemInstanceConfig, + AgenticSystemTurnResponseEvent, + AgenticSystemTurnResponseEventType, + AgenticSystemTurnResponseStepCompletePayload, + AgenticSystemTurnResponseStepProgressPayload, + AgenticSystemTurnResponseStepStartPayload, + AgenticSystemTurnResponseTurnCompletePayload, + AgenticSystemTurnResponseTurnStartPayload, + InferenceStep, + Session, + ShieldCallStep, + StepType, + ToolExecutionStep, + Turn, +) -from .api.endpoints import * # noqa - -from llama_toolchain.inference.api import ChatCompletionRequest +from llama_toolchain.inference.api import ChatCompletionRequest, Inference from llama_toolchain.inference.api.datatypes import ( Attachment, @@ -33,36 +45,16 @@ from llama_toolchain.inference.api.datatypes import ( ToolResponseMessage, URL, ) +from llama_toolchain.safety.api import Safety from llama_toolchain.safety.api.datatypes import ( BuiltinShield, ShieldDefinition, ShieldResponse, ) - from termcolor import cprint +from llama_toolchain.agentic_system.api.endpoints import * # noqa -from .api.datatypes import ( - AgenticSystemInstanceConfig, - AgenticSystemTurnResponseEvent, - AgenticSystemTurnResponseEventType, - AgenticSystemTurnResponseStepCompletePayload, - AgenticSystemTurnResponseStepProgressPayload, - AgenticSystemTurnResponseStepStartPayload, - AgenticSystemTurnResponseTurnCompletePayload, - AgenticSystemTurnResponseTurnStartPayload, - InferenceStep, - Session, - ShieldCallStep, - StepType, - ToolExecutionStep, - Turn, -) -from .api.endpoints import ( - AgenticSystemTurnCreateRequest, - AgenticSystemTurnResponseStreamChunk, -) from .safety import SafetyException, ShieldRunnerMixin - from .system_prompt import get_agentic_prefix_messages from .tools.base import BaseTool from .tools.builtin import SingleMessageBuiltinTool diff --git a/llama_toolchain/agentic_system/agentic_system.py b/llama_toolchain/agentic_system/meta_reference/agentic_system.py similarity index 97% rename from llama_toolchain/agentic_system/agentic_system.py rename to llama_toolchain/agentic_system/meta_reference/agentic_system.py index 81a4a3337..5db8d6168 100644 --- a/llama_toolchain/agentic_system/agentic_system.py +++ b/llama_toolchain/agentic_system/meta_reference/agentic_system.py @@ -5,25 +5,18 @@ # the root directory of this source tree. -from llama_toolchain.agentic_system.api import AgenticSystem - -from llama_toolchain.distribution.datatypes import Api, ProviderSpec -from llama_toolchain.inference.api import Inference -from llama_toolchain.safety.api import Safety - -from .config import AgenticSystemConfig -from .api.endpoints import * # noqa - import logging import os import uuid from typing import AsyncGenerator, Dict +from llama_toolchain.distribution.datatypes import Api, ProviderSpec +from llama_toolchain.inference.api import Inference from llama_toolchain.inference.api.datatypes import BuiltinTool - -from .agent_instance import AgentInstance - -from .api.endpoints import ( +from llama_toolchain.safety.api import Safety +from llama_toolchain.agentic_system.api.endpoints import * # noqa +from llama_toolchain.agentic_system.api import ( + AgenticSystem, AgenticSystemCreateRequest, AgenticSystemCreateResponse, AgenticSystemSessionCreateRequest, @@ -31,6 +24,10 @@ from .api.endpoints import ( AgenticSystemTurnCreateRequest, ) +from .agent_instance import AgentInstance + +from .config import AgenticSystemConfig + from .tools.builtin import ( BraveSearchTool, CodeInterpreterTool, diff --git a/llama_toolchain/agentic_system/config.py b/llama_toolchain/agentic_system/meta_reference/config.py similarity index 100% rename from llama_toolchain/agentic_system/config.py rename to llama_toolchain/agentic_system/meta_reference/config.py diff --git a/llama_toolchain/agentic_system/safety.py b/llama_toolchain/agentic_system/meta_reference/safety.py similarity index 100% rename from llama_toolchain/agentic_system/safety.py rename to llama_toolchain/agentic_system/meta_reference/safety.py diff --git a/llama_toolchain/agentic_system/system_prompt.py b/llama_toolchain/agentic_system/meta_reference/system_prompt.py similarity index 100% rename from llama_toolchain/agentic_system/system_prompt.py rename to llama_toolchain/agentic_system/meta_reference/system_prompt.py diff --git a/llama_toolchain/agentic_system/tools/ipython_tool/__init__.py b/llama_toolchain/agentic_system/meta_reference/tools/__init__.py similarity index 100% rename from llama_toolchain/agentic_system/tools/ipython_tool/__init__.py rename to llama_toolchain/agentic_system/meta_reference/tools/__init__.py diff --git a/llama_toolchain/agentic_system/tools/base.py b/llama_toolchain/agentic_system/meta_reference/tools/base.py similarity index 100% rename from llama_toolchain/agentic_system/tools/base.py rename to llama_toolchain/agentic_system/meta_reference/tools/base.py diff --git a/llama_toolchain/agentic_system/tools/builtin.py b/llama_toolchain/agentic_system/meta_reference/tools/builtin.py similarity index 100% rename from llama_toolchain/agentic_system/tools/builtin.py rename to llama_toolchain/agentic_system/meta_reference/tools/builtin.py diff --git a/llama_toolchain/safety/shields/contrib/__init__.py b/llama_toolchain/agentic_system/meta_reference/tools/ipython_tool/__init__.py similarity index 100% rename from llama_toolchain/safety/shields/contrib/__init__.py rename to llama_toolchain/agentic_system/meta_reference/tools/ipython_tool/__init__.py diff --git a/llama_toolchain/agentic_system/tools/ipython_tool/code_env_prefix.py b/llama_toolchain/agentic_system/meta_reference/tools/ipython_tool/code_env_prefix.py similarity index 100% rename from llama_toolchain/agentic_system/tools/ipython_tool/code_env_prefix.py rename to llama_toolchain/agentic_system/meta_reference/tools/ipython_tool/code_env_prefix.py diff --git a/llama_toolchain/agentic_system/tools/ipython_tool/code_execution.py b/llama_toolchain/agentic_system/meta_reference/tools/ipython_tool/code_execution.py similarity index 100% rename from llama_toolchain/agentic_system/tools/ipython_tool/code_execution.py rename to llama_toolchain/agentic_system/meta_reference/tools/ipython_tool/code_execution.py diff --git a/llama_toolchain/agentic_system/tools/ipython_tool/matplotlib_custom_backend.py b/llama_toolchain/agentic_system/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py similarity index 100% rename from llama_toolchain/agentic_system/tools/ipython_tool/matplotlib_custom_backend.py rename to llama_toolchain/agentic_system/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py diff --git a/llama_toolchain/agentic_system/tools/ipython_tool/utils.py b/llama_toolchain/agentic_system/meta_reference/tools/ipython_tool/utils.py similarity index 100% rename from llama_toolchain/agentic_system/tools/ipython_tool/utils.py rename to llama_toolchain/agentic_system/meta_reference/tools/ipython_tool/utils.py diff --git a/llama_toolchain/agentic_system/tools/safety.py b/llama_toolchain/agentic_system/meta_reference/tools/safety.py similarity index 95% rename from llama_toolchain/agentic_system/tools/safety.py rename to llama_toolchain/agentic_system/meta_reference/tools/safety.py index da0abe10a..aab67801d 100644 --- a/llama_toolchain/agentic_system/tools/safety.py +++ b/llama_toolchain/agentic_system/meta_reference/tools/safety.py @@ -6,7 +6,7 @@ from typing import List -from llama_toolchain.agentic_system.safety import ShieldRunnerMixin +from llama_toolchain.agentic_system.meta_reference.safety import ShieldRunnerMixin from llama_toolchain.inference.api import Message from llama_toolchain.safety.api.datatypes import ShieldDefinition diff --git a/llama_toolchain/agentic_system/providers.py b/llama_toolchain/agentic_system/providers.py index a1521fa46..4f8055f16 100644 --- a/llama_toolchain/agentic_system/providers.py +++ b/llama_toolchain/agentic_system/providers.py @@ -19,8 +19,8 @@ def available_agentic_system_providers() -> List[ProviderSpec]: "torch", "transformers", ], - module="llama_toolchain.agentic_system.agentic_system", - config_class="llama_toolchain.agentic_system.config.AgenticSystemConfig", + module="llama_toolchain.agentic_system.meta_reference", + config_class="llama_toolchain.agentic_system.meta_reference.AgenticSystemConfig", api_dependencies=[ Api.inference, Api.safety, diff --git a/llama_toolchain/agentic_system/tools/custom/__init__.py b/llama_toolchain/agentic_system/tools/custom/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_toolchain/agentic_system/tools/custom/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_toolchain/agentic_system/tools/custom.py b/llama_toolchain/agentic_system/tools/custom/datatypes.py similarity index 94% rename from llama_toolchain/agentic_system/tools/custom.py rename to llama_toolchain/agentic_system/tools/custom/datatypes.py index 35e3dd57d..ee46114e8 100644 --- a/llama_toolchain/agentic_system/tools/custom.py +++ b/llama_toolchain/agentic_system/tools/custom/datatypes.py @@ -12,7 +12,10 @@ from typing import Dict, List from llama_models.llama3_1.api.datatypes import * # noqa: F403 from llama_toolchain.agentic_system.api import * # noqa: F403 -from .builtin import interpret_content_as_attachment +# TODO: this is symptomatic of us needing to pull more tooling related utilities +from llama_toolchain.agentic_system.meta_reference.tools.builtin import ( + interpret_content_as_attachment, +) class CustomTool: diff --git a/llama_toolchain/agentic_system/tools/execute.py b/llama_toolchain/agentic_system/tools/custom/execute.py similarity index 100% rename from llama_toolchain/agentic_system/tools/execute.py rename to llama_toolchain/agentic_system/tools/custom/execute.py diff --git a/llama_toolchain/agentic_system/utils.py b/llama_toolchain/agentic_system/utils.py index 293d98944..299c5f93b 100644 --- a/llama_toolchain/agentic_system/utils.py +++ b/llama_toolchain/agentic_system/utils.py @@ -17,10 +17,15 @@ from llama_toolchain.agentic_system.api import ( ) from llama_toolchain.agentic_system.client import AgenticSystemClient -from llama_toolchain.agentic_system.tools.execute import execute_with_custom_tools +from llama_toolchain.agentic_system.tools.custom.execute import ( + execute_with_custom_tools, +) from llama_toolchain.safety.api.datatypes import BuiltinShield, ShieldDefinition +# TODO: this should move back to the llama-agentic-system repo + + class AgenticSystemClientWrapper: def __init__(self, api, system_id, custom_tools): diff --git a/llama_toolchain/inference/api_instance.py b/llama_toolchain/inference/api_instance.py deleted file mode 100644 index 560b99868..000000000 --- a/llama_toolchain/inference/api_instance.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -# from .api.config import ImplType, InferenceConfig - - -# async def get_inference_api_instance(config: InferenceConfig): -# if config.impl_config.impl_type == ImplType.inline.value: -# from .inference import InferenceImpl - -# return InferenceImpl(config.impl_config) -# elif config.impl_config.impl_type == ImplType.ollama.value: -# from .ollama import OllamaInference - -# return OllamaInference(config.impl_config) - -# from .client import InferenceClient - -# return InferenceClient(config.impl_config.url) diff --git a/llama_toolchain/inference/meta_reference/__init__.py b/llama_toolchain/inference/meta_reference/__init__.py new file mode 100644 index 000000000..87a08816e --- /dev/null +++ b/llama_toolchain/inference/meta_reference/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import MetaReferenceImplConfig # noqa +from .inference import get_provider_impl # noqa diff --git a/llama_toolchain/inference/api/config.py b/llama_toolchain/inference/meta_reference/config.py similarity index 87% rename from llama_toolchain/inference/api/config.py rename to llama_toolchain/inference/meta_reference/config.py index 2f01f90db..0f5bf8eb4 100644 --- a/llama_toolchain/inference/api/config.py +++ b/llama_toolchain/inference/meta_reference/config.py @@ -13,7 +13,7 @@ from pydantic import BaseModel, Field from strong_typing.schema import json_schema_type from typing_extensions import Annotated -from .datatypes import QuantizationConfig +from llama_toolchain.inference.api import QuantizationConfig @json_schema_type @@ -63,9 +63,3 @@ class MetaReferenceImplConfig(BaseModel): torch_seed: Optional[int] = None max_seq_len: int max_batch_size: int = 1 - - -@json_schema_type -class OllamaImplConfig(BaseModel): - model: str = Field(..., description="The name of the model in ollama catalog") - url: str = Field(..., description="The URL for the ollama server") diff --git a/llama_toolchain/inference/generation.py b/llama_toolchain/inference/meta_reference/generation.py similarity index 98% rename from llama_toolchain/inference/generation.py rename to llama_toolchain/inference/meta_reference/generation.py index be3b6967d..70580995d 100644 --- a/llama_toolchain/inference/generation.py +++ b/llama_toolchain/inference/meta_reference/generation.py @@ -29,8 +29,9 @@ from llama_models.llama3_1.api.model import Transformer from llama_models.llama3_1.api.tokenizer import Tokenizer from termcolor import cprint -from .api.config import CheckpointType, MetaReferenceImplConfig -from .api.datatypes import QuantizationType +from llama_toolchain.inference.api import QuantizationType + +from .config import CheckpointType, MetaReferenceImplConfig @dataclass diff --git a/llama_toolchain/inference/inference.py b/llama_toolchain/inference/meta_reference/inference.py similarity index 98% rename from llama_toolchain/inference/inference.py rename to llama_toolchain/inference/meta_reference/inference.py index 194a0a882..b41cb4acb 100644 --- a/llama_toolchain/inference/inference.py +++ b/llama_toolchain/inference/meta_reference/inference.py @@ -12,20 +12,18 @@ from llama_models.llama3_1.api.datatypes import StopReason from llama_models.sku_list import resolve_model from llama_toolchain.distribution.datatypes import Api, ProviderSpec - -from .api.config import MetaReferenceImplConfig -from .api.datatypes import ( +from llama_toolchain.inference.api import ( + ChatCompletionRequest, + ChatCompletionResponse, ChatCompletionResponseEvent, ChatCompletionResponseEventType, + ChatCompletionResponseStreamChunk, + Inference, ToolCallDelta, ToolCallParseStatus, ) -from .api.endpoints import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseStreamChunk, - Inference, -) + +from .config import MetaReferenceImplConfig from .model_parallel import LlamaModelParallelGenerator diff --git a/llama_toolchain/inference/model_parallel.py b/llama_toolchain/inference/meta_reference/model_parallel.py similarity index 98% rename from llama_toolchain/inference/model_parallel.py rename to llama_toolchain/inference/meta_reference/model_parallel.py index 8426f7890..58fbd2177 100644 --- a/llama_toolchain/inference/model_parallel.py +++ b/llama_toolchain/inference/meta_reference/model_parallel.py @@ -13,7 +13,7 @@ from llama_models.llama3_1.api.chat_format import ChatFormat from llama_models.llama3_1.api.datatypes import Message from llama_models.llama3_1.api.tokenizer import Tokenizer -from .api.config import MetaReferenceImplConfig +from .config import MetaReferenceImplConfig from .generation import Llama from .parallel_utils import ModelParallelProcessGroup diff --git a/llama_toolchain/inference/parallel_utils.py b/llama_toolchain/inference/meta_reference/parallel_utils.py similarity index 100% rename from llama_toolchain/inference/parallel_utils.py rename to llama_toolchain/inference/meta_reference/parallel_utils.py diff --git a/llama_toolchain/inference/ollama/__init__.py b/llama_toolchain/inference/ollama/__init__.py new file mode 100644 index 000000000..40d79618a --- /dev/null +++ b/llama_toolchain/inference/ollama/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import OllamaImplConfig # noqa +from .ollama import get_provider_impl # noqa diff --git a/llama_toolchain/inference/ollama/config.py b/llama_toolchain/inference/ollama/config.py new file mode 100644 index 000000000..11d47806c --- /dev/null +++ b/llama_toolchain/inference/ollama/config.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel, Field +from strong_typing.schema import json_schema_type + + +@json_schema_type +class OllamaImplConfig(BaseModel): + model: str = Field(..., description="The name of the model in ollama catalog") + url: str = Field(..., description="The URL for the ollama server") diff --git a/llama_toolchain/inference/ollama.py b/llama_toolchain/inference/ollama/ollama.py similarity index 99% rename from llama_toolchain/inference/ollama.py rename to llama_toolchain/inference/ollama/ollama.py index 9526a8665..3afd1326b 100644 --- a/llama_toolchain/inference/ollama.py +++ b/llama_toolchain/inference/ollama/ollama.py @@ -22,21 +22,20 @@ from llama_models.sku_list import resolve_model from ollama import AsyncClient -from .api.config import OllamaImplConfig -from .api.datatypes import ( - ChatCompletionResponseEvent, - ChatCompletionResponseEventType, - ToolCallDelta, - ToolCallParseStatus, -) -from .api.endpoints import ( +from llama_toolchain.inference.api import ( ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionResponseEvent, + ChatCompletionResponseEventType, ChatCompletionResponseStreamChunk, CompletionRequest, Inference, + ToolCallDelta, + ToolCallParseStatus, ) +from .config import OllamaImplConfig + # TODO: Eventually this will move to the llama cli model list command # mapping of Model SKUs to ollama models OLLAMA_SUPPORTED_SKUS = { diff --git a/llama_toolchain/inference/providers.py b/llama_toolchain/inference/providers.py index a12defafa..80428c069 100644 --- a/llama_toolchain/inference/providers.py +++ b/llama_toolchain/inference/providers.py @@ -18,8 +18,8 @@ def available_inference_providers() -> List[ProviderSpec]: "torch", "zmq", ], - module="llama_toolchain.inference.inference", - config_class="llama_toolchain.inference.inference.MetaReferenceImplConfig", + module="llama_toolchain.inference.meta_reference", + config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig", ), InlineProviderSpec( api=Api.inference, diff --git a/llama_toolchain/safety/meta_reference/__init__.py b/llama_toolchain/safety/meta_reference/__init__.py new file mode 100644 index 000000000..f874f3dad --- /dev/null +++ b/llama_toolchain/safety/meta_reference/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import SafetyConfig # noqa +from .safety import get_provider_impl # noqa diff --git a/llama_toolchain/safety/config.py b/llama_toolchain/safety/meta_reference/config.py similarity index 100% rename from llama_toolchain/safety/config.py rename to llama_toolchain/safety/meta_reference/config.py diff --git a/llama_toolchain/safety/safety.py b/llama_toolchain/safety/meta_reference/safety.py similarity index 98% rename from llama_toolchain/safety/safety.py rename to llama_toolchain/safety/meta_reference/safety.py index 3f1c7698c..93b986c12 100644 --- a/llama_toolchain/safety/safety.py +++ b/llama_toolchain/safety/meta_reference/safety.py @@ -11,7 +11,7 @@ from typing import Dict from llama_toolchain.distribution.datatypes import Api, ProviderSpec from .config import SafetyConfig -from .api.endpoints import * # noqa +from llama_toolchain.safety.api import * # noqa from .shields import ( CodeScannerShield, InjectionShield, diff --git a/llama_toolchain/safety/shields/__init__.py b/llama_toolchain/safety/meta_reference/shields/__init__.py similarity index 100% rename from llama_toolchain/safety/shields/__init__.py rename to llama_toolchain/safety/meta_reference/shields/__init__.py diff --git a/llama_toolchain/safety/shields/base.py b/llama_toolchain/safety/meta_reference/shields/base.py similarity index 100% rename from llama_toolchain/safety/shields/base.py rename to llama_toolchain/safety/meta_reference/shields/base.py diff --git a/llama_toolchain/safety/shields/code_scanner.py b/llama_toolchain/safety/meta_reference/shields/code_scanner.py similarity index 100% rename from llama_toolchain/safety/shields/code_scanner.py rename to llama_toolchain/safety/meta_reference/shields/code_scanner.py diff --git a/llama_toolchain/safety/meta_reference/shields/contrib/__init__.py b/llama_toolchain/safety/meta_reference/shields/contrib/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_toolchain/safety/meta_reference/shields/contrib/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_toolchain/safety/shields/contrib/third_party_shield.py b/llama_toolchain/safety/meta_reference/shields/contrib/third_party_shield.py similarity index 89% rename from llama_toolchain/safety/shields/contrib/third_party_shield.py rename to llama_toolchain/safety/meta_reference/shields/contrib/third_party_shield.py index da5282cbe..789fa5f07 100644 --- a/llama_toolchain/safety/shields/contrib/third_party_shield.py +++ b/llama_toolchain/safety/meta_reference/shields/contrib/third_party_shield.py @@ -4,14 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import sys from typing import List from llama_models.llama3_1.api.datatypes import Message -parent_dir = "../.." -sys.path.append(parent_dir) -from llama_toolchain.safety.shields.base import ( +from llama_toolchain.safety.meta_reference.shields.base import ( OnViolationAction, ShieldBase, ShieldResponse, diff --git a/llama_toolchain/safety/shields/llama_guard.py b/llama_toolchain/safety/meta_reference/shields/llama_guard.py similarity index 100% rename from llama_toolchain/safety/shields/llama_guard.py rename to llama_toolchain/safety/meta_reference/shields/llama_guard.py diff --git a/llama_toolchain/safety/shields/prompt_guard.py b/llama_toolchain/safety/meta_reference/shields/prompt_guard.py similarity index 100% rename from llama_toolchain/safety/shields/prompt_guard.py rename to llama_toolchain/safety/meta_reference/shields/prompt_guard.py diff --git a/llama_toolchain/safety/providers.py b/llama_toolchain/safety/providers.py index 4a88c8e28..f8e2e0a86 100644 --- a/llama_toolchain/safety/providers.py +++ b/llama_toolchain/safety/providers.py @@ -19,7 +19,7 @@ def available_safety_providers() -> List[ProviderSpec]: "torch", "transformers", ], - module="llama_toolchain.safety.safety", - config_class="llama_toolchain.safety.config.SafetyConfig", + module="llama_toolchain.safety.meta_reference", + config_class="llama_toolchain.safety.meta_reference.SafetyConfig", ), ]