mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
combine datatypes.py and endpoints.py into api.py
This commit is contained in:
parent
c1078a60e7
commit
3230af4910
30 changed files with 436 additions and 546 deletions
|
@ -4,5 +4,4 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from .datatypes import * # noqa
|
from .api import * # noqa: F401 F403
|
||||||
from .endpoints import * # noqa
|
|
||||||
|
|
|
@ -6,9 +6,9 @@
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Literal, Optional, Union
|
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
@ -320,3 +320,96 @@ class AgenticSystemTurnResponseEvent(BaseModel):
|
||||||
],
|
],
|
||||||
Field(discriminator="event_type"),
|
Field(discriminator="event_type"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgenticSystemCreateResponse(BaseModel):
|
||||||
|
agent_id: str
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgenticSystemSessionCreateResponse(BaseModel):
|
||||||
|
session_id: str
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn):
|
||||||
|
agent_id: str
|
||||||
|
session_id: str
|
||||||
|
|
||||||
|
# TODO: figure out how we can simplify this and make why
|
||||||
|
# ToolResponseMessage needs to be here (it is function call
|
||||||
|
# execution from outside the system)
|
||||||
|
messages: List[
|
||||||
|
Union[
|
||||||
|
UserMessage,
|
||||||
|
ToolResponseMessage,
|
||||||
|
]
|
||||||
|
]
|
||||||
|
attachments: Optional[List[Attachment]] = None
|
||||||
|
|
||||||
|
stream: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type(
|
||||||
|
schema={"description": "Server side event (SSE) stream of these events"}
|
||||||
|
)
|
||||||
|
class AgenticSystemTurnResponseStreamChunk(BaseModel):
|
||||||
|
event: AgenticSystemTurnResponseEvent
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgenticSystemStepResponse(BaseModel):
|
||||||
|
step: Step
|
||||||
|
|
||||||
|
|
||||||
|
class AgenticSystem(Protocol):
|
||||||
|
@webmethod(route="/agentic_system/create")
|
||||||
|
async def create_agentic_system(
|
||||||
|
self,
|
||||||
|
agent_config: AgentConfig,
|
||||||
|
) -> AgenticSystemCreateResponse: ...
|
||||||
|
|
||||||
|
@webmethod(route="/agentic_system/turn/create")
|
||||||
|
async def create_agentic_system_turn(
|
||||||
|
self,
|
||||||
|
request: AgenticSystemTurnCreateRequest,
|
||||||
|
) -> AgenticSystemTurnResponseStreamChunk: ...
|
||||||
|
|
||||||
|
@webmethod(route="/agentic_system/turn/get")
|
||||||
|
async def get_agentic_system_turn(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
turn_id: str,
|
||||||
|
) -> Turn: ...
|
||||||
|
|
||||||
|
@webmethod(route="/agentic_system/step/get")
|
||||||
|
async def get_agentic_system_step(
|
||||||
|
self, agent_id: str, turn_id: str, step_id: str
|
||||||
|
) -> AgenticSystemStepResponse: ...
|
||||||
|
|
||||||
|
@webmethod(route="/agentic_system/session/create")
|
||||||
|
async def create_agentic_system_session(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
session_name: str,
|
||||||
|
) -> AgenticSystemSessionCreateResponse: ...
|
||||||
|
|
||||||
|
@webmethod(route="/agentic_system/session/get")
|
||||||
|
async def get_agentic_system_session(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
session_id: str,
|
||||||
|
turn_ids: Optional[List[str]] = None,
|
||||||
|
) -> Session: ...
|
||||||
|
|
||||||
|
@webmethod(route="/agentic_system/session/delete")
|
||||||
|
async def delete_agentic_system_session(
|
||||||
|
self, agent_id: str, session_id: str
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
@webmethod(route="/agentic_system/delete")
|
||||||
|
async def delete_agentic_system(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
) -> None: ...
|
|
@ -1,103 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from .datatypes import * # noqa: F403
|
|
||||||
from typing import Protocol
|
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class AgenticSystemCreateResponse(BaseModel):
|
|
||||||
agent_id: str
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class AgenticSystemSessionCreateResponse(BaseModel):
|
|
||||||
session_id: str
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn):
|
|
||||||
agent_id: str
|
|
||||||
session_id: str
|
|
||||||
|
|
||||||
# TODO: figure out how we can simplify this and make why
|
|
||||||
# ToolResponseMessage needs to be here (it is function call
|
|
||||||
# execution from outside the system)
|
|
||||||
messages: List[
|
|
||||||
Union[
|
|
||||||
UserMessage,
|
|
||||||
ToolResponseMessage,
|
|
||||||
]
|
|
||||||
]
|
|
||||||
attachments: Optional[List[Attachment]] = None
|
|
||||||
|
|
||||||
stream: Optional[bool] = False
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type(
|
|
||||||
schema={"description": "Server side event (SSE) stream of these events"}
|
|
||||||
)
|
|
||||||
class AgenticSystemTurnResponseStreamChunk(BaseModel):
|
|
||||||
event: AgenticSystemTurnResponseEvent
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class AgenticSystemStepResponse(BaseModel):
|
|
||||||
step: Step
|
|
||||||
|
|
||||||
|
|
||||||
class AgenticSystem(Protocol):
|
|
||||||
@webmethod(route="/agentic_system/create")
|
|
||||||
async def create_agentic_system(
|
|
||||||
self,
|
|
||||||
agent_config: AgentConfig,
|
|
||||||
) -> AgenticSystemCreateResponse: ...
|
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/turn/create")
|
|
||||||
async def create_agentic_system_turn(
|
|
||||||
self,
|
|
||||||
request: AgenticSystemTurnCreateRequest,
|
|
||||||
) -> AgenticSystemTurnResponseStreamChunk: ...
|
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/turn/get")
|
|
||||||
async def get_agentic_system_turn(
|
|
||||||
self,
|
|
||||||
agent_id: str,
|
|
||||||
turn_id: str,
|
|
||||||
) -> Turn: ...
|
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/step/get")
|
|
||||||
async def get_agentic_system_step(
|
|
||||||
self, agent_id: str, turn_id: str, step_id: str
|
|
||||||
) -> AgenticSystemStepResponse: ...
|
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/session/create")
|
|
||||||
async def create_agentic_system_session(
|
|
||||||
self,
|
|
||||||
agent_id: str,
|
|
||||||
session_name: str,
|
|
||||||
) -> AgenticSystemSessionCreateResponse: ...
|
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/session/get")
|
|
||||||
async def get_agentic_system_session(
|
|
||||||
self,
|
|
||||||
agent_id: str,
|
|
||||||
session_id: str,
|
|
||||||
turn_ids: Optional[List[str]] = None,
|
|
||||||
) -> Session: ...
|
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/session/delete")
|
|
||||||
async def delete_agentic_system_session(
|
|
||||||
self, agent_id: str, session_id: str
|
|
||||||
) -> None: ...
|
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/delete")
|
|
||||||
async def delete_agentic_system(
|
|
||||||
self,
|
|
||||||
agent_id: str,
|
|
||||||
) -> None: ...
|
|
|
@ -4,5 +4,4 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from .datatypes import * # noqa: F401 F403
|
from .api import * # noqa: F401 F403
|
||||||
from .endpoints import * # noqa: F401 F403
|
|
||||||
|
|
|
@ -4,13 +4,34 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Protocol
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, Optional, Protocol
|
||||||
|
|
||||||
|
from llama_models.llama3.api.datatypes import URL
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from .datatypes import * # noqa: F403
|
|
||||||
|
@json_schema_type
|
||||||
|
class TrainEvalDatasetColumnType(Enum):
|
||||||
|
dialog = "dialog"
|
||||||
|
text = "text"
|
||||||
|
media = "media"
|
||||||
|
number = "number"
|
||||||
|
json = "json"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class TrainEvalDataset(BaseModel):
|
||||||
|
"""Dataset to be used for training or evaluating language models."""
|
||||||
|
|
||||||
|
# TODO(ashwin): figure out if we need to add an enum for a "dataset type"
|
||||||
|
|
||||||
|
columns: Dict[str, TrainEvalDatasetColumnType]
|
||||||
|
content_url: URL
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
|
@ -1,34 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import URL
|
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class TrainEvalDatasetColumnType(Enum):
|
|
||||||
dialog = "dialog"
|
|
||||||
text = "text"
|
|
||||||
media = "media"
|
|
||||||
number = "number"
|
|
||||||
json = "json"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class TrainEvalDataset(BaseModel):
|
|
||||||
"""Dataset to be used for training or evaluating language models."""
|
|
||||||
|
|
||||||
# TODO(ashwin): figure out if we need to add an enum for a "dataset type"
|
|
||||||
|
|
||||||
columns: Dict[str, TrainEvalDatasetColumnType]
|
|
||||||
content_url: URL
|
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
|
|
@ -4,5 +4,4 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from .datatypes import * # noqa: F401 F403
|
from .api import * # noqa: F401 F403
|
||||||
from .endpoints import * # noqa: F401 F403
|
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
from typing import List, Protocol
|
from typing import List, Protocol
|
||||||
|
|
||||||
from llama_models.schema_utils import webmethod
|
from llama_models.schema_utils import webmethod
|
||||||
|
@ -11,11 +12,34 @@ from llama_models.schema_utils import webmethod
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from .datatypes import * # noqa: F403
|
|
||||||
from llama_toolchain.dataset.api.datatypes import * # noqa: F403
|
from llama_toolchain.dataset.api.datatypes import * # noqa: F403
|
||||||
from llama_toolchain.common.training_types import * # noqa: F403
|
from llama_toolchain.common.training_types import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
class TextGenerationMetric(Enum):
|
||||||
|
perplexity = "perplexity"
|
||||||
|
rouge = "rouge"
|
||||||
|
bleu = "bleu"
|
||||||
|
|
||||||
|
|
||||||
|
class QuestionAnsweringMetric(Enum):
|
||||||
|
em = "em"
|
||||||
|
f1 = "f1"
|
||||||
|
|
||||||
|
|
||||||
|
class SummarizationMetric(Enum):
|
||||||
|
rouge = "rouge"
|
||||||
|
bleu = "bleu"
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluationJob(BaseModel):
|
||||||
|
job_uuid: str
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluationJobLogStream(BaseModel):
|
||||||
|
job_uuid: str
|
||||||
|
|
||||||
|
|
||||||
class EvaluateTaskRequestCommon(BaseModel):
|
class EvaluateTaskRequestCommon(BaseModel):
|
||||||
job_uuid: str
|
job_uuid: str
|
||||||
dataset: TrainEvalDataset
|
dataset: TrainEvalDataset
|
|
@ -1,33 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class TextGenerationMetric(Enum):
|
|
||||||
perplexity = "perplexity"
|
|
||||||
rouge = "rouge"
|
|
||||||
bleu = "bleu"
|
|
||||||
|
|
||||||
|
|
||||||
class QuestionAnsweringMetric(Enum):
|
|
||||||
em = "em"
|
|
||||||
f1 = "f1"
|
|
||||||
|
|
||||||
|
|
||||||
class SummarizationMetric(Enum):
|
|
||||||
rouge = "rouge"
|
|
||||||
bleu = "bleu"
|
|
||||||
|
|
||||||
|
|
||||||
class EvaluationJob(BaseModel):
|
|
||||||
job_uuid: str
|
|
||||||
|
|
||||||
|
|
||||||
class EvaluationJobLogStream(BaseModel):
|
|
||||||
job_uuid: str
|
|
|
@ -4,5 +4,4 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from .datatypes import * # noqa: F401 F403
|
from .api import * # noqa: F401 F403
|
||||||
from .endpoints import * # noqa: F401 F403
|
|
||||||
|
|
|
@ -4,13 +4,73 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from .datatypes import * # noqa: F403
|
from enum import Enum
|
||||||
from typing import Optional, Protocol
|
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
|
from typing import List, Literal, Optional, Protocol, Union
|
||||||
|
|
||||||
# this dependency is annoying and we need a forked up version anyway
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
from llama_models.schema_utils import webmethod
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
class LogProbConfig(BaseModel):
|
||||||
|
top_k: Optional[int] = 0
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class QuantizationType(Enum):
|
||||||
|
bf16 = "bf16"
|
||||||
|
fp8 = "fp8"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class Fp8QuantizationConfig(BaseModel):
|
||||||
|
type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class Bf16QuantizationConfig(BaseModel):
|
||||||
|
type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
|
||||||
|
|
||||||
|
|
||||||
|
QuantizationConfig = Annotated[
|
||||||
|
Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
|
||||||
|
Field(discriminator="type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ChatCompletionResponseEventType(Enum):
|
||||||
|
start = "start"
|
||||||
|
complete = "complete"
|
||||||
|
progress = "progress"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ToolCallParseStatus(Enum):
|
||||||
|
started = "started"
|
||||||
|
in_progress = "in_progress"
|
||||||
|
failure = "failure"
|
||||||
|
success = "success"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ToolCallDelta(BaseModel):
|
||||||
|
content: Union[str, ToolCall]
|
||||||
|
parse_status: ToolCallParseStatus
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ChatCompletionResponseEvent(BaseModel):
|
||||||
|
"""Chat completion response event."""
|
||||||
|
|
||||||
|
event_type: ChatCompletionResponseEventType
|
||||||
|
delta: Union[str, ToolCallDelta]
|
||||||
|
logprobs: Optional[List[TokenLogProbs]] = None
|
||||||
|
stop_reason: Optional[StopReason] = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
|
@ -1,72 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from enum import Enum
|
|
||||||
from typing import List, Literal, Optional, Union
|
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from typing_extensions import Annotated
|
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
|
||||||
|
|
||||||
|
|
||||||
class LogProbConfig(BaseModel):
|
|
||||||
top_k: Optional[int] = 0
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class QuantizationType(Enum):
|
|
||||||
bf16 = "bf16"
|
|
||||||
fp8 = "fp8"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class Fp8QuantizationConfig(BaseModel):
|
|
||||||
type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class Bf16QuantizationConfig(BaseModel):
|
|
||||||
type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
|
|
||||||
|
|
||||||
|
|
||||||
QuantizationConfig = Annotated[
|
|
||||||
Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
|
|
||||||
Field(discriminator="type"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ChatCompletionResponseEventType(Enum):
|
|
||||||
start = "start"
|
|
||||||
complete = "complete"
|
|
||||||
progress = "progress"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ToolCallParseStatus(Enum):
|
|
||||||
started = "started"
|
|
||||||
in_progress = "in_progress"
|
|
||||||
failure = "failure"
|
|
||||||
success = "success"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ToolCallDelta(BaseModel):
|
|
||||||
content: Union[str, ToolCall]
|
|
||||||
parse_status: ToolCallParseStatus
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ChatCompletionResponseEvent(BaseModel):
|
|
||||||
"""Chat completion response event."""
|
|
||||||
|
|
||||||
event_type: ChatCompletionResponseEventType
|
|
||||||
delta: Union[str, ToolCallDelta]
|
|
||||||
logprobs: Optional[List[TokenLogProbs]] = None
|
|
||||||
stop_reason: Optional[StopReason] = None
|
|
|
@ -4,5 +4,4 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from .datatypes import * # noqa: F401 F403
|
from .api import * # noqa: F401 F403
|
||||||
from .endpoints import * # noqa: F401 F403
|
|
||||||
|
|
|
@ -3,17 +3,21 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import List, Optional, Protocol
|
from typing import List, Optional, Protocol
|
||||||
|
|
||||||
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
|
||||||
from llama_models.schema_utils import webmethod
|
|
||||||
from .datatypes import * # noqa: F403
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class MemoryBankDocument(BaseModel):
|
class MemoryBankDocument(BaseModel):
|
|
@ -1,5 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
|
@ -4,5 +4,4 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from .datatypes import * # noqa: F401 F403
|
from .api import * # noqa: F401 F403
|
||||||
from .endpoints import * # noqa: F401 F403
|
|
||||||
|
|
|
@ -5,12 +5,79 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional, Protocol
|
from enum import Enum
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional, Protocol, Union
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
|
||||||
from .datatypes import * # noqa: F403
|
|
||||||
|
@json_schema_type
|
||||||
|
class ExperimentStatus(Enum):
|
||||||
|
NOT_STARTED = "not_started"
|
||||||
|
RUNNING = "running"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class Experiment(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
status: ExperimentStatus
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
metadata: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class Run(BaseModel):
|
||||||
|
id: str
|
||||||
|
experiment_id: str
|
||||||
|
status: str
|
||||||
|
started_at: datetime
|
||||||
|
ended_at: Optional[datetime]
|
||||||
|
metadata: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class Metric(BaseModel):
|
||||||
|
name: str
|
||||||
|
value: Union[float, int, str, bool]
|
||||||
|
timestamp: datetime
|
||||||
|
run_id: str
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class Log(BaseModel):
|
||||||
|
message: str
|
||||||
|
level: str
|
||||||
|
timestamp: datetime
|
||||||
|
additional_info: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ArtifactType(Enum):
|
||||||
|
MODEL = "model"
|
||||||
|
DATASET = "dataset"
|
||||||
|
CHECKPOINT = "checkpoint"
|
||||||
|
PLOT = "plot"
|
||||||
|
METRIC = "metric"
|
||||||
|
CONFIG = "config"
|
||||||
|
CODE = "code"
|
||||||
|
OTHER = "other"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class Artifact(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
type: ArtifactType
|
||||||
|
size: int
|
||||||
|
created_at: datetime
|
||||||
|
metadata: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
|
@ -1,80 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Union
|
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ExperimentStatus(Enum):
|
|
||||||
NOT_STARTED = "not_started"
|
|
||||||
RUNNING = "running"
|
|
||||||
COMPLETED = "completed"
|
|
||||||
FAILED = "failed"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class Experiment(BaseModel):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
status: ExperimentStatus
|
|
||||||
created_at: datetime
|
|
||||||
updated_at: datetime
|
|
||||||
metadata: Dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class Run(BaseModel):
|
|
||||||
id: str
|
|
||||||
experiment_id: str
|
|
||||||
status: str
|
|
||||||
started_at: datetime
|
|
||||||
ended_at: Optional[datetime]
|
|
||||||
metadata: Dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class Metric(BaseModel):
|
|
||||||
name: str
|
|
||||||
value: Union[float, int, str, bool]
|
|
||||||
timestamp: datetime
|
|
||||||
run_id: str
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class Log(BaseModel):
|
|
||||||
message: str
|
|
||||||
level: str
|
|
||||||
timestamp: datetime
|
|
||||||
additional_info: Dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ArtifactType(Enum):
|
|
||||||
MODEL = "model"
|
|
||||||
DATASET = "dataset"
|
|
||||||
CHECKPOINT = "checkpoint"
|
|
||||||
PLOT = "plot"
|
|
||||||
METRIC = "metric"
|
|
||||||
CONFIG = "config"
|
|
||||||
CODE = "code"
|
|
||||||
OTHER = "other"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class Artifact(BaseModel):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
type: ArtifactType
|
|
||||||
size: int
|
|
||||||
created_at: datetime
|
|
||||||
metadata: Dict[str, Any]
|
|
|
@ -4,5 +4,4 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from .datatypes import * # noqa: F401 F403
|
from .api import * # noqa: F401 F403
|
||||||
from .endpoints import * # noqa: F401 F403
|
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Protocol
|
from typing import Any, Dict, List, Optional, Protocol
|
||||||
|
|
||||||
|
@ -15,7 +16,88 @@ from pydantic import BaseModel, Field
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_toolchain.dataset.api.datatypes import * # noqa: F403
|
from llama_toolchain.dataset.api.datatypes import * # noqa: F403
|
||||||
from llama_toolchain.common.training_types import * # noqa: F403
|
from llama_toolchain.common.training_types import * # noqa: F403
|
||||||
from .datatypes import * # noqa: F403
|
|
||||||
|
|
||||||
|
class OptimizerType(Enum):
|
||||||
|
adam = "adam"
|
||||||
|
adamw = "adamw"
|
||||||
|
sgd = "sgd"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OptimizerConfig(BaseModel):
|
||||||
|
optimizer_type: OptimizerType
|
||||||
|
lr: float
|
||||||
|
lr_min: float
|
||||||
|
weight_decay: float
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class TrainingConfig(BaseModel):
|
||||||
|
n_epochs: int
|
||||||
|
batch_size: int
|
||||||
|
shuffle: bool
|
||||||
|
n_iters: int
|
||||||
|
|
||||||
|
enable_activation_checkpointing: bool
|
||||||
|
memory_efficient_fsdp_wrap: bool
|
||||||
|
fsdp_cpu_offload: bool
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class FinetuningAlgorithm(Enum):
|
||||||
|
full = "full"
|
||||||
|
lora = "lora"
|
||||||
|
qlora = "qlora"
|
||||||
|
dora = "dora"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class LoraFinetuningConfig(BaseModel):
|
||||||
|
lora_attn_modules: List[str]
|
||||||
|
apply_lora_to_mlp: bool
|
||||||
|
apply_lora_to_output: bool
|
||||||
|
rank: int
|
||||||
|
alpha: int
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class QLoraFinetuningConfig(LoraFinetuningConfig):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class DoraFinetuningConfig(LoraFinetuningConfig):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class PostTrainingJobLogStream(BaseModel):
|
||||||
|
"""Stream of logs from a finetuning job."""
|
||||||
|
|
||||||
|
job_uuid: str
|
||||||
|
log_lines: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class PostTrainingJobStatus(Enum):
|
||||||
|
running = "running"
|
||||||
|
completed = "completed"
|
||||||
|
failed = "failed"
|
||||||
|
scheduled = "scheduled"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class RLHFAlgorithm(Enum):
|
||||||
|
dpo = "dpo"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class DPOAlignmentConfig(BaseModel):
|
||||||
|
reward_scale: float
|
||||||
|
reward_clip: float
|
||||||
|
epsilon: float
|
||||||
|
gamma: float
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
|
@ -1,94 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from enum import Enum
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class OptimizerType(Enum):
|
|
||||||
adam = "adam"
|
|
||||||
adamw = "adamw"
|
|
||||||
sgd = "sgd"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class OptimizerConfig(BaseModel):
|
|
||||||
optimizer_type: OptimizerType
|
|
||||||
lr: float
|
|
||||||
lr_min: float
|
|
||||||
weight_decay: float
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class TrainingConfig(BaseModel):
|
|
||||||
n_epochs: int
|
|
||||||
batch_size: int
|
|
||||||
shuffle: bool
|
|
||||||
n_iters: int
|
|
||||||
|
|
||||||
enable_activation_checkpointing: bool
|
|
||||||
memory_efficient_fsdp_wrap: bool
|
|
||||||
fsdp_cpu_offload: bool
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class FinetuningAlgorithm(Enum):
|
|
||||||
full = "full"
|
|
||||||
lora = "lora"
|
|
||||||
qlora = "qlora"
|
|
||||||
dora = "dora"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class LoraFinetuningConfig(BaseModel):
|
|
||||||
lora_attn_modules: List[str]
|
|
||||||
apply_lora_to_mlp: bool
|
|
||||||
apply_lora_to_output: bool
|
|
||||||
rank: int
|
|
||||||
alpha: int
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class QLoraFinetuningConfig(LoraFinetuningConfig):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class DoraFinetuningConfig(LoraFinetuningConfig):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class PostTrainingJobLogStream(BaseModel):
|
|
||||||
"""Stream of logs from a finetuning job."""
|
|
||||||
|
|
||||||
job_uuid: str
|
|
||||||
log_lines: List[str]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class PostTrainingJobStatus(Enum):
|
|
||||||
running = "running"
|
|
||||||
completed = "completed"
|
|
||||||
failed = "failed"
|
|
||||||
scheduled = "scheduled"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class RLHFAlgorithm(Enum):
|
|
||||||
dpo = "dpo"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class DPOAlignmentConfig(BaseModel):
|
|
||||||
reward_scale: float
|
|
||||||
reward_clip: float
|
|
||||||
epsilon: float
|
|
||||||
gamma: float
|
|
|
@ -4,5 +4,4 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from .datatypes import * # noqa: F401 F403
|
from .api import * # noqa: F401 F403
|
||||||
from .endpoints import * # noqa: F401 F403
|
|
||||||
|
|
|
@ -5,9 +5,30 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import List, Protocol, Union
|
from typing import List, Protocol, Union
|
||||||
from .datatypes import * # noqa: F403
|
|
||||||
|
|
||||||
from llama_models.schema_utils import webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ScoredMessage(BaseModel):
|
||||||
|
message: Message
|
||||||
|
score: float
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class DialogGenerations(BaseModel):
|
||||||
|
dialog: List[Message]
|
||||||
|
sampled_generations: List[Message]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ScoredDialogGenerations(BaseModel):
|
||||||
|
dialog: List[Message]
|
||||||
|
scored_generations: List[ScoredMessage]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
|
@ -1,31 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ScoredMessage(BaseModel):
|
|
||||||
message: Message
|
|
||||||
score: float
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class DialogGenerations(BaseModel):
|
|
||||||
dialog: List[Message]
|
|
||||||
sampled_generations: List[Message]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ScoredDialogGenerations(BaseModel):
|
|
||||||
dialog: List[Message]
|
|
||||||
scored_generations: List[ScoredMessage]
|
|
|
@ -4,5 +4,4 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from .datatypes import * # noqa
|
from .api import * # noqa: F401 F403
|
||||||
from .endpoints import * # noqa
|
|
||||||
|
|
|
@ -5,13 +5,12 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, List, Optional, Protocol, Union
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import ToolParamDefinition
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
|
||||||
|
|
||||||
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
from pydantic import BaseModel, validator
|
from pydantic import BaseModel, validator
|
||||||
|
|
||||||
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
|
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
|
||||||
|
|
||||||
|
|
||||||
|
@ -70,3 +69,22 @@ class ShieldResponse(BaseModel):
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return v
|
return v
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class RunShieldRequest(BaseModel):
|
||||||
|
messages: List[Message]
|
||||||
|
shields: List[ShieldDefinition]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class RunShieldResponse(BaseModel):
|
||||||
|
responses: List[ShieldResponse]
|
||||||
|
|
||||||
|
|
||||||
|
class Safety(Protocol):
|
||||||
|
@webmethod(route="/safety/run_shields")
|
||||||
|
async def run_shields(
|
||||||
|
self,
|
||||||
|
request: RunShieldRequest,
|
||||||
|
) -> RunShieldResponse: ...
|
|
@ -1,32 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from .datatypes import * # noqa: F403
|
|
||||||
from typing import List, Protocol
|
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import Message
|
|
||||||
|
|
||||||
# this dependency is annoying and we need a forked up version anyway
|
|
||||||
from llama_models.schema_utils import webmethod
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class RunShieldRequest(BaseModel):
|
|
||||||
messages: List[Message]
|
|
||||||
shields: List[ShieldDefinition]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class RunShieldResponse(BaseModel):
|
|
||||||
responses: List[ShieldResponse]
|
|
||||||
|
|
||||||
|
|
||||||
class Safety(Protocol):
|
|
||||||
@webmethod(route="/safety/run_shields")
|
|
||||||
async def run_shields(
|
|
||||||
self,
|
|
||||||
request: RunShieldRequest,
|
|
||||||
) -> RunShieldResponse: ...
|
|
|
@ -4,5 +4,4 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from .datatypes import * # noqa: F401 F403
|
from .api import * # noqa: F401 F403
|
||||||
from .endpoints import * # noqa: F401 F403
|
|
||||||
|
|
|
@ -4,6 +4,8 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Protocol
|
from typing import Any, Dict, List, Optional, Protocol
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
|
@ -12,7 +14,17 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_toolchain.reward_scoring.api.datatypes import * # noqa: F403
|
from llama_toolchain.reward_scoring.api.datatypes import * # noqa: F403
|
||||||
from .datatypes import * # noqa: F403
|
|
||||||
|
|
||||||
|
class FilteringFunction(Enum):
|
||||||
|
"""The type of filtering function."""
|
||||||
|
|
||||||
|
none = "none"
|
||||||
|
random = "random"
|
||||||
|
top_k = "top_k"
|
||||||
|
top_p = "top_p"
|
||||||
|
top_k_top_p = "top_k_top_p"
|
||||||
|
sigmoid = "sigmoid"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
|
@ -1,18 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
|
|
||||||
class FilteringFunction(Enum):
|
|
||||||
"""The type of filtering function."""
|
|
||||||
|
|
||||||
none = "none"
|
|
||||||
random = "random"
|
|
||||||
top_k = "top_k"
|
|
||||||
top_p = "top_p"
|
|
||||||
top_k_top_p = "top_k_top_p"
|
|
||||||
sigmoid = "sigmoid"
|
|
Loading…
Add table
Add a link
Reference in a new issue