From 3230af49105a2009168f77a7a2ff2cb8bf0f78c5 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 26 Aug 2024 12:55:28 -0700 Subject: [PATCH] combine datatypes.py and endpoints.py into api.py --- .../agentic_system/api/__init__.py | 3 +- .../api/{datatypes.py => api.py} | 97 ++++++++++++++++- .../agentic_system/api/endpoints.py | 103 ------------------ llama_toolchain/dataset/api/__init__.py | 3 +- .../dataset/api/{endpoints.py => api.py} | 25 ++++- llama_toolchain/dataset/api/datatypes.py | 34 ------ llama_toolchain/evaluations/api/__init__.py | 3 +- .../evaluations/api/{endpoints.py => api.py} | 26 ++++- llama_toolchain/evaluations/api/datatypes.py | 33 ------ llama_toolchain/inference/api/__init__.py | 3 +- .../inference/api/{endpoints.py => api.py} | 70 +++++++++++- llama_toolchain/inference/api/datatypes.py | 72 ------------ llama_toolchain/memory/api/__init__.py | 3 +- .../memory/api/{endpoints.py => api.py} | 10 +- llama_toolchain/memory/api/datatypes.py | 5 - llama_toolchain/observability/api/__init__.py | 3 +- .../api/{endpoints.py => api.py} | 73 ++++++++++++- .../observability/api/datatypes.py | 80 -------------- llama_toolchain/post_training/api/__init__.py | 3 +- .../api/{endpoints.py => api.py} | 84 +++++++++++++- .../post_training/api/datatypes.py | 94 ---------------- .../reward_scoring/api/__init__.py | 3 +- .../api/{endpoints.py => api.py} | 25 ++++- .../reward_scoring/api/datatypes.py | 31 ------ llama_toolchain/safety/api/__init__.py | 3 +- .../safety/api/{datatypes.py => api.py} | 26 ++++- llama_toolchain/safety/api/endpoints.py | 32 ------ .../synthetic_data_generation/api/__init__.py | 3 +- .../api/{endpoints.py => api.py} | 14 ++- .../api/datatypes.py | 18 --- 30 files changed, 436 insertions(+), 546 deletions(-) rename llama_toolchain/agentic_system/api/{datatypes.py => api.py} (78%) delete mode 100644 llama_toolchain/agentic_system/api/endpoints.py rename llama_toolchain/dataset/api/{endpoints.py => api.py} (58%) delete mode 100644 llama_toolchain/dataset/api/datatypes.py rename llama_toolchain/evaluations/api/{endpoints.py => api.py} (87%) delete mode 100644 llama_toolchain/evaluations/api/datatypes.py rename llama_toolchain/inference/api/{endpoints.py => api.py} (70%) delete mode 100644 llama_toolchain/inference/api/datatypes.py rename llama_toolchain/memory/api/{endpoints.py => api.py} (93%) delete mode 100644 llama_toolchain/memory/api/datatypes.py rename llama_toolchain/observability/api/{endpoints.py => api.py} (70%) delete mode 100644 llama_toolchain/observability/api/datatypes.py rename llama_toolchain/post_training/api/{endpoints.py => api.py} (69%) delete mode 100644 llama_toolchain/post_training/api/datatypes.py rename llama_toolchain/reward_scoring/api/{endpoints.py => api.py} (63%) delete mode 100644 llama_toolchain/reward_scoring/api/datatypes.py rename llama_toolchain/safety/api/{datatypes.py => api.py} (75%) delete mode 100644 llama_toolchain/safety/api/endpoints.py rename llama_toolchain/synthetic_data_generation/api/{endpoints.py => api.py} (84%) delete mode 100644 llama_toolchain/synthetic_data_generation/api/datatypes.py diff --git a/llama_toolchain/agentic_system/api/__init__.py b/llama_toolchain/agentic_system/api/__init__.py index 4cefa053f..a7e55ba91 100644 --- a/llama_toolchain/agentic_system/api/__init__.py +++ b/llama_toolchain/agentic_system/api/__init__.py @@ -4,5 +4,4 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .datatypes import * # noqa -from .endpoints import * # noqa +from .api import * # noqa: F401 F403 diff --git a/llama_toolchain/agentic_system/api/datatypes.py b/llama_toolchain/agentic_system/api/api.py similarity index 78% rename from llama_toolchain/agentic_system/api/datatypes.py rename to llama_toolchain/agentic_system/api/api.py index c22d71635..056d5ab67 100644 --- a/llama_toolchain/agentic_system/api/datatypes.py +++ b/llama_toolchain/agentic_system/api/api.py @@ -6,9 +6,9 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Protocol, Union -from llama_models.schema_utils import json_schema_type +from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated @@ -320,3 +320,96 @@ class AgenticSystemTurnResponseEvent(BaseModel): ], Field(discriminator="event_type"), ] + + +@json_schema_type +class AgenticSystemCreateResponse(BaseModel): + agent_id: str + + +@json_schema_type +class AgenticSystemSessionCreateResponse(BaseModel): + session_id: str + + +@json_schema_type +class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn): + agent_id: str + session_id: str + + # TODO: figure out how we can simplify this and make why + # ToolResponseMessage needs to be here (it is function call + # execution from outside the system) + messages: List[ + Union[ + UserMessage, + ToolResponseMessage, + ] + ] + attachments: Optional[List[Attachment]] = None + + stream: Optional[bool] = False + + +@json_schema_type( + schema={"description": "Server side event (SSE) stream of these events"} +) +class AgenticSystemTurnResponseStreamChunk(BaseModel): + event: AgenticSystemTurnResponseEvent + + +@json_schema_type +class AgenticSystemStepResponse(BaseModel): + step: Step + + +class AgenticSystem(Protocol): + @webmethod(route="/agentic_system/create") + async def create_agentic_system( + self, + agent_config: AgentConfig, + ) -> AgenticSystemCreateResponse: ... + + @webmethod(route="/agentic_system/turn/create") + async def create_agentic_system_turn( + self, + request: AgenticSystemTurnCreateRequest, + ) -> AgenticSystemTurnResponseStreamChunk: ... + + @webmethod(route="/agentic_system/turn/get") + async def get_agentic_system_turn( + self, + agent_id: str, + turn_id: str, + ) -> Turn: ... + + @webmethod(route="/agentic_system/step/get") + async def get_agentic_system_step( + self, agent_id: str, turn_id: str, step_id: str + ) -> AgenticSystemStepResponse: ... + + @webmethod(route="/agentic_system/session/create") + async def create_agentic_system_session( + self, + agent_id: str, + session_name: str, + ) -> AgenticSystemSessionCreateResponse: ... + + @webmethod(route="/agentic_system/session/get") + async def get_agentic_system_session( + self, + agent_id: str, + session_id: str, + turn_ids: Optional[List[str]] = None, + ) -> Session: ... + + @webmethod(route="/agentic_system/session/delete") + async def delete_agentic_system_session( + self, agent_id: str, session_id: str + ) -> None: ... + + @webmethod(route="/agentic_system/delete") + async def delete_agentic_system( + self, + agent_id: str, + ) -> None: ... diff --git a/llama_toolchain/agentic_system/api/endpoints.py b/llama_toolchain/agentic_system/api/endpoints.py deleted file mode 100644 index 663edeb8d..000000000 --- a/llama_toolchain/agentic_system/api/endpoints.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .datatypes import * # noqa: F403 -from typing import Protocol - -from llama_models.schema_utils import json_schema_type, webmethod - - -@json_schema_type -class AgenticSystemCreateResponse(BaseModel): - agent_id: str - - -@json_schema_type -class AgenticSystemSessionCreateResponse(BaseModel): - session_id: str - - -@json_schema_type -class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn): - agent_id: str - session_id: str - - # TODO: figure out how we can simplify this and make why - # ToolResponseMessage needs to be here (it is function call - # execution from outside the system) - messages: List[ - Union[ - UserMessage, - ToolResponseMessage, - ] - ] - attachments: Optional[List[Attachment]] = None - - stream: Optional[bool] = False - - -@json_schema_type( - schema={"description": "Server side event (SSE) stream of these events"} -) -class AgenticSystemTurnResponseStreamChunk(BaseModel): - event: AgenticSystemTurnResponseEvent - - -@json_schema_type -class AgenticSystemStepResponse(BaseModel): - step: Step - - -class AgenticSystem(Protocol): - @webmethod(route="/agentic_system/create") - async def create_agentic_system( - self, - agent_config: AgentConfig, - ) -> AgenticSystemCreateResponse: ... - - @webmethod(route="/agentic_system/turn/create") - async def create_agentic_system_turn( - self, - request: AgenticSystemTurnCreateRequest, - ) -> AgenticSystemTurnResponseStreamChunk: ... - - @webmethod(route="/agentic_system/turn/get") - async def get_agentic_system_turn( - self, - agent_id: str, - turn_id: str, - ) -> Turn: ... - - @webmethod(route="/agentic_system/step/get") - async def get_agentic_system_step( - self, agent_id: str, turn_id: str, step_id: str - ) -> AgenticSystemStepResponse: ... - - @webmethod(route="/agentic_system/session/create") - async def create_agentic_system_session( - self, - agent_id: str, - session_name: str, - ) -> AgenticSystemSessionCreateResponse: ... - - @webmethod(route="/agentic_system/session/get") - async def get_agentic_system_session( - self, - agent_id: str, - session_id: str, - turn_ids: Optional[List[str]] = None, - ) -> Session: ... - - @webmethod(route="/agentic_system/session/delete") - async def delete_agentic_system_session( - self, agent_id: str, session_id: str - ) -> None: ... - - @webmethod(route="/agentic_system/delete") - async def delete_agentic_system( - self, - agent_id: str, - ) -> None: ... diff --git a/llama_toolchain/dataset/api/__init__.py b/llama_toolchain/dataset/api/__init__.py index 647bd4a5f..a7e55ba91 100644 --- a/llama_toolchain/dataset/api/__init__.py +++ b/llama_toolchain/dataset/api/__init__.py @@ -4,5 +4,4 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .datatypes import * # noqa: F401 F403 -from .endpoints import * # noqa: F401 F403 +from .api import * # noqa: F401 F403 diff --git a/llama_toolchain/dataset/api/endpoints.py b/llama_toolchain/dataset/api/api.py similarity index 58% rename from llama_toolchain/dataset/api/endpoints.py rename to llama_toolchain/dataset/api/api.py index 6a88f4b7a..c22fc01b0 100644 --- a/llama_toolchain/dataset/api/endpoints.py +++ b/llama_toolchain/dataset/api/api.py @@ -4,13 +4,34 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Protocol +from enum import Enum +from typing import Any, Dict, Optional, Protocol + +from llama_models.llama3.api.datatypes import URL from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel -from .datatypes import * # noqa: F403 + +@json_schema_type +class TrainEvalDatasetColumnType(Enum): + dialog = "dialog" + text = "text" + media = "media" + number = "number" + json = "json" + + +@json_schema_type +class TrainEvalDataset(BaseModel): + """Dataset to be used for training or evaluating language models.""" + + # TODO(ashwin): figure out if we need to add an enum for a "dataset type" + + columns: Dict[str, TrainEvalDatasetColumnType] + content_url: URL + metadata: Optional[Dict[str, Any]] = None @json_schema_type diff --git a/llama_toolchain/dataset/api/datatypes.py b/llama_toolchain/dataset/api/datatypes.py deleted file mode 100644 index 32109b37c..000000000 --- a/llama_toolchain/dataset/api/datatypes.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from enum import Enum -from typing import Any, Dict, Optional - -from llama_models.llama3.api.datatypes import URL - -from llama_models.schema_utils import json_schema_type - -from pydantic import BaseModel - - -@json_schema_type -class TrainEvalDatasetColumnType(Enum): - dialog = "dialog" - text = "text" - media = "media" - number = "number" - json = "json" - - -@json_schema_type -class TrainEvalDataset(BaseModel): - """Dataset to be used for training or evaluating language models.""" - - # TODO(ashwin): figure out if we need to add an enum for a "dataset type" - - columns: Dict[str, TrainEvalDatasetColumnType] - content_url: URL - metadata: Optional[Dict[str, Any]] = None diff --git a/llama_toolchain/evaluations/api/__init__.py b/llama_toolchain/evaluations/api/__init__.py index 647bd4a5f..a7e55ba91 100644 --- a/llama_toolchain/evaluations/api/__init__.py +++ b/llama_toolchain/evaluations/api/__init__.py @@ -4,5 +4,4 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .datatypes import * # noqa: F401 F403 -from .endpoints import * # noqa: F401 F403 +from .api import * # noqa: F401 F403 diff --git a/llama_toolchain/evaluations/api/endpoints.py b/llama_toolchain/evaluations/api/api.py similarity index 87% rename from llama_toolchain/evaluations/api/endpoints.py rename to llama_toolchain/evaluations/api/api.py index 25fb570f7..3e03fe12e 100644 --- a/llama_toolchain/evaluations/api/endpoints.py +++ b/llama_toolchain/evaluations/api/api.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from enum import Enum from typing import List, Protocol from llama_models.schema_utils import webmethod @@ -11,11 +12,34 @@ from llama_models.schema_utils import webmethod from pydantic import BaseModel from llama_models.llama3.api.datatypes import * # noqa: F403 -from .datatypes import * # noqa: F403 from llama_toolchain.dataset.api.datatypes import * # noqa: F403 from llama_toolchain.common.training_types import * # noqa: F403 +class TextGenerationMetric(Enum): + perplexity = "perplexity" + rouge = "rouge" + bleu = "bleu" + + +class QuestionAnsweringMetric(Enum): + em = "em" + f1 = "f1" + + +class SummarizationMetric(Enum): + rouge = "rouge" + bleu = "bleu" + + +class EvaluationJob(BaseModel): + job_uuid: str + + +class EvaluationJobLogStream(BaseModel): + job_uuid: str + + class EvaluateTaskRequestCommon(BaseModel): job_uuid: str dataset: TrainEvalDataset diff --git a/llama_toolchain/evaluations/api/datatypes.py b/llama_toolchain/evaluations/api/datatypes.py deleted file mode 100644 index 0ba284e9d..000000000 --- a/llama_toolchain/evaluations/api/datatypes.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from enum import Enum - -from pydantic import BaseModel - - -class TextGenerationMetric(Enum): - perplexity = "perplexity" - rouge = "rouge" - bleu = "bleu" - - -class QuestionAnsweringMetric(Enum): - em = "em" - f1 = "f1" - - -class SummarizationMetric(Enum): - rouge = "rouge" - bleu = "bleu" - - -class EvaluationJob(BaseModel): - job_uuid: str - - -class EvaluationJobLogStream(BaseModel): - job_uuid: str diff --git a/llama_toolchain/inference/api/__init__.py b/llama_toolchain/inference/api/__init__.py index 647bd4a5f..a7e55ba91 100644 --- a/llama_toolchain/inference/api/__init__.py +++ b/llama_toolchain/inference/api/__init__.py @@ -4,5 +4,4 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .datatypes import * # noqa: F401 F403 -from .endpoints import * # noqa: F401 F403 +from .api import * # noqa: F401 F403 diff --git a/llama_toolchain/inference/api/endpoints.py b/llama_toolchain/inference/api/api.py similarity index 70% rename from llama_toolchain/inference/api/endpoints.py rename to llama_toolchain/inference/api/api.py index f09c0e3f8..cf72ef5fd 100644 --- a/llama_toolchain/inference/api/endpoints.py +++ b/llama_toolchain/inference/api/api.py @@ -4,13 +4,73 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .datatypes import * # noqa: F403 -from typing import Optional, Protocol +from enum import Enum -from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat +from typing import List, Literal, Optional, Protocol, Union -# this dependency is annoying and we need a forked up version anyway -from llama_models.schema_utils import webmethod +from llama_models.schema_utils import json_schema_type, webmethod + +from pydantic import BaseModel, Field +from typing_extensions import Annotated + +from llama_models.llama3.api.datatypes import * # noqa: F403 + + +class LogProbConfig(BaseModel): + top_k: Optional[int] = 0 + + +@json_schema_type +class QuantizationType(Enum): + bf16 = "bf16" + fp8 = "fp8" + + +@json_schema_type +class Fp8QuantizationConfig(BaseModel): + type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value + + +@json_schema_type +class Bf16QuantizationConfig(BaseModel): + type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value + + +QuantizationConfig = Annotated[ + Union[Bf16QuantizationConfig, Fp8QuantizationConfig], + Field(discriminator="type"), +] + + +@json_schema_type +class ChatCompletionResponseEventType(Enum): + start = "start" + complete = "complete" + progress = "progress" + + +@json_schema_type +class ToolCallParseStatus(Enum): + started = "started" + in_progress = "in_progress" + failure = "failure" + success = "success" + + +@json_schema_type +class ToolCallDelta(BaseModel): + content: Union[str, ToolCall] + parse_status: ToolCallParseStatus + + +@json_schema_type +class ChatCompletionResponseEvent(BaseModel): + """Chat completion response event.""" + + event_type: ChatCompletionResponseEventType + delta: Union[str, ToolCallDelta] + logprobs: Optional[List[TokenLogProbs]] = None + stop_reason: Optional[StopReason] = None @json_schema_type diff --git a/llama_toolchain/inference/api/datatypes.py b/llama_toolchain/inference/api/datatypes.py deleted file mode 100644 index 571ecc3ea..000000000 --- a/llama_toolchain/inference/api/datatypes.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from enum import Enum -from typing import List, Literal, Optional, Union - -from llama_models.schema_utils import json_schema_type - -from pydantic import BaseModel, Field -from typing_extensions import Annotated - -from llama_models.llama3.api.datatypes import * # noqa: F403 - - -class LogProbConfig(BaseModel): - top_k: Optional[int] = 0 - - -@json_schema_type -class QuantizationType(Enum): - bf16 = "bf16" - fp8 = "fp8" - - -@json_schema_type -class Fp8QuantizationConfig(BaseModel): - type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value - - -@json_schema_type -class Bf16QuantizationConfig(BaseModel): - type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value - - -QuantizationConfig = Annotated[ - Union[Bf16QuantizationConfig, Fp8QuantizationConfig], - Field(discriminator="type"), -] - - -@json_schema_type -class ChatCompletionResponseEventType(Enum): - start = "start" - complete = "complete" - progress = "progress" - - -@json_schema_type -class ToolCallParseStatus(Enum): - started = "started" - in_progress = "in_progress" - failure = "failure" - success = "success" - - -@json_schema_type -class ToolCallDelta(BaseModel): - content: Union[str, ToolCall] - parse_status: ToolCallParseStatus - - -@json_schema_type -class ChatCompletionResponseEvent(BaseModel): - """Chat completion response event.""" - - event_type: ChatCompletionResponseEventType - delta: Union[str, ToolCallDelta] - logprobs: Optional[List[TokenLogProbs]] = None - stop_reason: Optional[StopReason] = None diff --git a/llama_toolchain/memory/api/__init__.py b/llama_toolchain/memory/api/__init__.py index 647bd4a5f..a7e55ba91 100644 --- a/llama_toolchain/memory/api/__init__.py +++ b/llama_toolchain/memory/api/__init__.py @@ -4,5 +4,4 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .datatypes import * # noqa: F401 F403 -from .endpoints import * # noqa: F401 F403 +from .api import * # noqa: F401 F403 diff --git a/llama_toolchain/memory/api/endpoints.py b/llama_toolchain/memory/api/api.py similarity index 93% rename from llama_toolchain/memory/api/endpoints.py rename to llama_toolchain/memory/api/api.py index 9299872e3..9b86d29a2 100644 --- a/llama_toolchain/memory/api/endpoints.py +++ b/llama_toolchain/memory/api/api.py @@ -3,17 +3,21 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. from typing import List, Optional, Protocol +from llama_models.schema_utils import json_schema_type, webmethod + from pydantic import BaseModel, Field from typing_extensions import Annotated from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_models.schema_utils import webmethod -from .datatypes import * # noqa: F403 - @json_schema_type class MemoryBankDocument(BaseModel): diff --git a/llama_toolchain/memory/api/datatypes.py b/llama_toolchain/memory/api/datatypes.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_toolchain/memory/api/datatypes.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_toolchain/observability/api/__init__.py b/llama_toolchain/observability/api/__init__.py index 647bd4a5f..a7e55ba91 100644 --- a/llama_toolchain/observability/api/__init__.py +++ b/llama_toolchain/observability/api/__init__.py @@ -4,5 +4,4 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .datatypes import * # noqa: F401 F403 -from .endpoints import * # noqa: F401 F403 +from .api import * # noqa: F401 F403 diff --git a/llama_toolchain/observability/api/endpoints.py b/llama_toolchain/observability/api/api.py similarity index 70% rename from llama_toolchain/observability/api/endpoints.py rename to llama_toolchain/observability/api/api.py index 3f993ac2d..86a5cc703 100644 --- a/llama_toolchain/observability/api/endpoints.py +++ b/llama_toolchain/observability/api/api.py @@ -5,12 +5,79 @@ # the root directory of this source tree. from datetime import datetime -from typing import Any, Dict, List, Optional, Protocol +from enum import Enum + +from typing import Any, Dict, List, Optional, Protocol, Union from llama_models.schema_utils import json_schema_type, webmethod + from pydantic import BaseModel -from llama_models.llama3.api.datatypes import * # noqa: F403 -from .datatypes import * # noqa: F403 + + +@json_schema_type +class ExperimentStatus(Enum): + NOT_STARTED = "not_started" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + + +@json_schema_type +class Experiment(BaseModel): + id: str + name: str + status: ExperimentStatus + created_at: datetime + updated_at: datetime + metadata: Dict[str, Any] + + +@json_schema_type +class Run(BaseModel): + id: str + experiment_id: str + status: str + started_at: datetime + ended_at: Optional[datetime] + metadata: Dict[str, Any] + + +@json_schema_type +class Metric(BaseModel): + name: str + value: Union[float, int, str, bool] + timestamp: datetime + run_id: str + + +@json_schema_type +class Log(BaseModel): + message: str + level: str + timestamp: datetime + additional_info: Dict[str, Any] + + +@json_schema_type +class ArtifactType(Enum): + MODEL = "model" + DATASET = "dataset" + CHECKPOINT = "checkpoint" + PLOT = "plot" + METRIC = "metric" + CONFIG = "config" + CODE = "code" + OTHER = "other" + + +@json_schema_type +class Artifact(BaseModel): + id: str + name: str + type: ArtifactType + size: int + created_at: datetime + metadata: Dict[str, Any] @json_schema_type diff --git a/llama_toolchain/observability/api/datatypes.py b/llama_toolchain/observability/api/datatypes.py deleted file mode 100644 index 42f95b64c..000000000 --- a/llama_toolchain/observability/api/datatypes.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from datetime import datetime -from enum import Enum - -from typing import Any, Dict, Optional, Union - -from llama_models.schema_utils import json_schema_type - -from pydantic import BaseModel - - -@json_schema_type -class ExperimentStatus(Enum): - NOT_STARTED = "not_started" - RUNNING = "running" - COMPLETED = "completed" - FAILED = "failed" - - -@json_schema_type -class Experiment(BaseModel): - id: str - name: str - status: ExperimentStatus - created_at: datetime - updated_at: datetime - metadata: Dict[str, Any] - - -@json_schema_type -class Run(BaseModel): - id: str - experiment_id: str - status: str - started_at: datetime - ended_at: Optional[datetime] - metadata: Dict[str, Any] - - -@json_schema_type -class Metric(BaseModel): - name: str - value: Union[float, int, str, bool] - timestamp: datetime - run_id: str - - -@json_schema_type -class Log(BaseModel): - message: str - level: str - timestamp: datetime - additional_info: Dict[str, Any] - - -@json_schema_type -class ArtifactType(Enum): - MODEL = "model" - DATASET = "dataset" - CHECKPOINT = "checkpoint" - PLOT = "plot" - METRIC = "metric" - CONFIG = "config" - CODE = "code" - OTHER = "other" - - -@json_schema_type -class Artifact(BaseModel): - id: str - name: str - type: ArtifactType - size: int - created_at: datetime - metadata: Dict[str, Any] diff --git a/llama_toolchain/post_training/api/__init__.py b/llama_toolchain/post_training/api/__init__.py index 647bd4a5f..a7e55ba91 100644 --- a/llama_toolchain/post_training/api/__init__.py +++ b/llama_toolchain/post_training/api/__init__.py @@ -4,5 +4,4 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .datatypes import * # noqa: F401 F403 -from .endpoints import * # noqa: F401 F403 +from .api import * # noqa: F401 F403 diff --git a/llama_toolchain/post_training/api/endpoints.py b/llama_toolchain/post_training/api/api.py similarity index 69% rename from llama_toolchain/post_training/api/endpoints.py rename to llama_toolchain/post_training/api/api.py index f0536ee4c..ce7dcd65c 100644 --- a/llama_toolchain/post_training/api/endpoints.py +++ b/llama_toolchain/post_training/api/api.py @@ -5,6 +5,7 @@ # the root directory of this source tree. from datetime import datetime +from enum import Enum from typing import Any, Dict, List, Optional, Protocol @@ -15,7 +16,88 @@ from pydantic import BaseModel, Field from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_toolchain.dataset.api.datatypes import * # noqa: F403 from llama_toolchain.common.training_types import * # noqa: F403 -from .datatypes import * # noqa: F403 + + +class OptimizerType(Enum): + adam = "adam" + adamw = "adamw" + sgd = "sgd" + + +@json_schema_type +class OptimizerConfig(BaseModel): + optimizer_type: OptimizerType + lr: float + lr_min: float + weight_decay: float + + +@json_schema_type +class TrainingConfig(BaseModel): + n_epochs: int + batch_size: int + shuffle: bool + n_iters: int + + enable_activation_checkpointing: bool + memory_efficient_fsdp_wrap: bool + fsdp_cpu_offload: bool + + +@json_schema_type +class FinetuningAlgorithm(Enum): + full = "full" + lora = "lora" + qlora = "qlora" + dora = "dora" + + +@json_schema_type +class LoraFinetuningConfig(BaseModel): + lora_attn_modules: List[str] + apply_lora_to_mlp: bool + apply_lora_to_output: bool + rank: int + alpha: int + + +@json_schema_type +class QLoraFinetuningConfig(LoraFinetuningConfig): + pass + + +@json_schema_type +class DoraFinetuningConfig(LoraFinetuningConfig): + pass + + +@json_schema_type +class PostTrainingJobLogStream(BaseModel): + """Stream of logs from a finetuning job.""" + + job_uuid: str + log_lines: List[str] + + +@json_schema_type +class PostTrainingJobStatus(Enum): + running = "running" + completed = "completed" + failed = "failed" + scheduled = "scheduled" + + +@json_schema_type +class RLHFAlgorithm(Enum): + dpo = "dpo" + + +@json_schema_type +class DPOAlignmentConfig(BaseModel): + reward_scale: float + reward_clip: float + epsilon: float + gamma: float @json_schema_type diff --git a/llama_toolchain/post_training/api/datatypes.py b/llama_toolchain/post_training/api/datatypes.py deleted file mode 100644 index 45a259f03..000000000 --- a/llama_toolchain/post_training/api/datatypes.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from enum import Enum -from typing import List - -from llama_models.schema_utils import json_schema_type - -from pydantic import BaseModel - - -class OptimizerType(Enum): - adam = "adam" - adamw = "adamw" - sgd = "sgd" - - -@json_schema_type -class OptimizerConfig(BaseModel): - optimizer_type: OptimizerType - lr: float - lr_min: float - weight_decay: float - - -@json_schema_type -class TrainingConfig(BaseModel): - n_epochs: int - batch_size: int - shuffle: bool - n_iters: int - - enable_activation_checkpointing: bool - memory_efficient_fsdp_wrap: bool - fsdp_cpu_offload: bool - - -@json_schema_type -class FinetuningAlgorithm(Enum): - full = "full" - lora = "lora" - qlora = "qlora" - dora = "dora" - - -@json_schema_type -class LoraFinetuningConfig(BaseModel): - lora_attn_modules: List[str] - apply_lora_to_mlp: bool - apply_lora_to_output: bool - rank: int - alpha: int - - -@json_schema_type -class QLoraFinetuningConfig(LoraFinetuningConfig): - pass - - -@json_schema_type -class DoraFinetuningConfig(LoraFinetuningConfig): - pass - - -@json_schema_type -class PostTrainingJobLogStream(BaseModel): - """Stream of logs from a finetuning job.""" - - job_uuid: str - log_lines: List[str] - - -@json_schema_type -class PostTrainingJobStatus(Enum): - running = "running" - completed = "completed" - failed = "failed" - scheduled = "scheduled" - - -@json_schema_type -class RLHFAlgorithm(Enum): - dpo = "dpo" - - -@json_schema_type -class DPOAlignmentConfig(BaseModel): - reward_scale: float - reward_clip: float - epsilon: float - gamma: float diff --git a/llama_toolchain/reward_scoring/api/__init__.py b/llama_toolchain/reward_scoring/api/__init__.py index 647bd4a5f..a7e55ba91 100644 --- a/llama_toolchain/reward_scoring/api/__init__.py +++ b/llama_toolchain/reward_scoring/api/__init__.py @@ -4,5 +4,4 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .datatypes import * # noqa: F401 F403 -from .endpoints import * # noqa: F401 F403 +from .api import * # noqa: F401 F403 diff --git a/llama_toolchain/reward_scoring/api/endpoints.py b/llama_toolchain/reward_scoring/api/api.py similarity index 63% rename from llama_toolchain/reward_scoring/api/endpoints.py rename to llama_toolchain/reward_scoring/api/api.py index 657e7b325..c91931f09 100644 --- a/llama_toolchain/reward_scoring/api/endpoints.py +++ b/llama_toolchain/reward_scoring/api/api.py @@ -5,9 +5,30 @@ # the root directory of this source tree. from typing import List, Protocol, Union -from .datatypes import * # noqa: F403 -from llama_models.schema_utils import webmethod +from llama_models.schema_utils import json_schema_type, webmethod + +from pydantic import BaseModel + +from llama_models.llama3.api.datatypes import * # noqa: F403 + + +@json_schema_type +class ScoredMessage(BaseModel): + message: Message + score: float + + +@json_schema_type +class DialogGenerations(BaseModel): + dialog: List[Message] + sampled_generations: List[Message] + + +@json_schema_type +class ScoredDialogGenerations(BaseModel): + dialog: List[Message] + scored_generations: List[ScoredMessage] @json_schema_type diff --git a/llama_toolchain/reward_scoring/api/datatypes.py b/llama_toolchain/reward_scoring/api/datatypes.py deleted file mode 100644 index 2ce698d47..000000000 --- a/llama_toolchain/reward_scoring/api/datatypes.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import List - -from llama_models.schema_utils import json_schema_type - -from pydantic import BaseModel - -from llama_models.llama3.api.datatypes import * # noqa: F403 - - -@json_schema_type -class ScoredMessage(BaseModel): - message: Message - score: float - - -@json_schema_type -class DialogGenerations(BaseModel): - dialog: List[Message] - sampled_generations: List[Message] - - -@json_schema_type -class ScoredDialogGenerations(BaseModel): - dialog: List[Message] - scored_generations: List[ScoredMessage] diff --git a/llama_toolchain/safety/api/__init__.py b/llama_toolchain/safety/api/__init__.py index 4cefa053f..a7e55ba91 100644 --- a/llama_toolchain/safety/api/__init__.py +++ b/llama_toolchain/safety/api/__init__.py @@ -4,5 +4,4 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .datatypes import * # noqa -from .endpoints import * # noqa +from .api import * # noqa: F401 F403 diff --git a/llama_toolchain/safety/api/datatypes.py b/llama_toolchain/safety/api/api.py similarity index 75% rename from llama_toolchain/safety/api/datatypes.py rename to llama_toolchain/safety/api/api.py index 5deecc2b3..96682d172 100644 --- a/llama_toolchain/safety/api/datatypes.py +++ b/llama_toolchain/safety/api/api.py @@ -5,13 +5,12 @@ # the root directory of this source tree. from enum import Enum -from typing import Dict, Optional, Union - -from llama_models.llama3.api.datatypes import ToolParamDefinition -from llama_models.schema_utils import json_schema_type +from typing import Dict, List, Optional, Protocol, Union +from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, validator +from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_toolchain.common.deployment_types import RestAPIExecutionConfig @@ -70,3 +69,22 @@ class ShieldResponse(BaseModel): except ValueError: return v return v + + +@json_schema_type +class RunShieldRequest(BaseModel): + messages: List[Message] + shields: List[ShieldDefinition] + + +@json_schema_type +class RunShieldResponse(BaseModel): + responses: List[ShieldResponse] + + +class Safety(Protocol): + @webmethod(route="/safety/run_shields") + async def run_shields( + self, + request: RunShieldRequest, + ) -> RunShieldResponse: ... diff --git a/llama_toolchain/safety/api/endpoints.py b/llama_toolchain/safety/api/endpoints.py deleted file mode 100644 index a282a7968..000000000 --- a/llama_toolchain/safety/api/endpoints.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .datatypes import * # noqa: F403 -from typing import List, Protocol - -from llama_models.llama3.api.datatypes import Message - -# this dependency is annoying and we need a forked up version anyway -from llama_models.schema_utils import webmethod - - -@json_schema_type -class RunShieldRequest(BaseModel): - messages: List[Message] - shields: List[ShieldDefinition] - - -@json_schema_type -class RunShieldResponse(BaseModel): - responses: List[ShieldResponse] - - -class Safety(Protocol): - @webmethod(route="/safety/run_shields") - async def run_shields( - self, - request: RunShieldRequest, - ) -> RunShieldResponse: ... diff --git a/llama_toolchain/synthetic_data_generation/api/__init__.py b/llama_toolchain/synthetic_data_generation/api/__init__.py index 647bd4a5f..a7e55ba91 100644 --- a/llama_toolchain/synthetic_data_generation/api/__init__.py +++ b/llama_toolchain/synthetic_data_generation/api/__init__.py @@ -4,5 +4,4 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .datatypes import * # noqa: F401 F403 -from .endpoints import * # noqa: F401 F403 +from .api import * # noqa: F401 F403 diff --git a/llama_toolchain/synthetic_data_generation/api/endpoints.py b/llama_toolchain/synthetic_data_generation/api/api.py similarity index 84% rename from llama_toolchain/synthetic_data_generation/api/endpoints.py rename to llama_toolchain/synthetic_data_generation/api/api.py index d6b9c83d5..4d82553a3 100644 --- a/llama_toolchain/synthetic_data_generation/api/endpoints.py +++ b/llama_toolchain/synthetic_data_generation/api/api.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from enum import Enum + from typing import Any, Dict, List, Optional, Protocol from llama_models.schema_utils import json_schema_type, webmethod @@ -12,7 +14,17 @@ from pydantic import BaseModel from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_toolchain.reward_scoring.api.datatypes import * # noqa: F403 -from .datatypes import * # noqa: F403 + + +class FilteringFunction(Enum): + """The type of filtering function.""" + + none = "none" + random = "random" + top_k = "top_k" + top_p = "top_p" + top_k_top_p = "top_k_top_p" + sigmoid = "sigmoid" @json_schema_type diff --git a/llama_toolchain/synthetic_data_generation/api/datatypes.py b/llama_toolchain/synthetic_data_generation/api/datatypes.py deleted file mode 100644 index 1cef6653b..000000000 --- a/llama_toolchain/synthetic_data_generation/api/datatypes.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from enum import Enum - - -class FilteringFunction(Enum): - """The type of filtering function.""" - - none = "none" - random = "random" - top_k = "top_k" - top_p = "top_p" - top_k_top_p = "top_k_top_p" - sigmoid = "sigmoid"