mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 02:12:26 +00:00
Merge branch 'main' into rag_scoring_fn_1
This commit is contained in:
commit
d62f1040fe
128 changed files with 6391 additions and 493 deletions
|
|
@ -18,18 +18,30 @@ from typing import (
|
|||
Union,
|
||||
)
|
||||
|
||||
from llama_models.llama3.api.datatypes import ToolParamDefinition
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.common.deployment_types import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, URL
|
||||
from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionMessage,
|
||||
SamplingParams,
|
||||
ToolCall,
|
||||
ToolCallDelta,
|
||||
ToolChoice,
|
||||
ToolPromptFormat,
|
||||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.memory import MemoryBank
|
||||
from llama_stack.apis.safety import SafetyViolation
|
||||
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -6,13 +6,14 @@
|
|||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_models.llama3.api.datatypes import ToolPromptFormat
|
||||
from llama_models.llama3.api.tool_utils import ToolUtils
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
|
||||
|
||||
from llama_stack.apis.inference import ToolResponseMessage
|
||||
|
||||
|
||||
class LogEvent:
|
||||
def __init__(
|
||||
|
|
|
|||
|
|
@ -10,8 +10,16 @@ from llama_models.schema_utils import json_schema_type, webmethod
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionMessage,
|
||||
InterleavedContent,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -4,11 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, register_schema
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field, field_serializer, model_validator
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -27,6 +28,12 @@ class _URLOrData(BaseModel):
|
|||
return values
|
||||
return {"url": values}
|
||||
|
||||
@field_serializer("data")
|
||||
def serialize_data(self, data: Optional[bytes], _info):
|
||||
if data is None:
|
||||
return None
|
||||
return base64.b64encode(data).decode("utf-8")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ImageContentItem(_URLOrData):
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
|||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -4,18 +4,18 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Literal, Optional, Protocol, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||
|
||||
from llama_models.llama3.api.datatypes import BaseModel, Field
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F403
|
||||
from llama_stack.apis.agents import AgentConfig
|
||||
from llama_stack.apis.common.job_types import Job, JobStatus
|
||||
from llama_stack.apis.scoring import * # noqa: F403
|
||||
from llama_stack.apis.eval_tasks import * # noqa: F403
|
||||
from llama_stack.apis.inference import SamplingParams, SystemMessage
|
||||
from llama_stack.apis.scoring import ScoringResult
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -7,7 +7,9 @@
|
|||
from enum import Enum
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
|
|
@ -32,8 +34,9 @@ from typing_extensions import Annotated
|
|||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
|
||||
from llama_stack.apis.models import Model
|
||||
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
|
||||
|
||||
class LogProbConfig(BaseModel):
|
||||
|
|
|
|||
|
|
@ -7,17 +7,17 @@
|
|||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from typing import Any, Dict, List, Optional, Protocol, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
|
||||
from llama_stack.apis.common.job_types import JobStatus
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.common.training_types import * # noqa: F403
|
||||
from llama_stack.apis.common.training_types import Checkpoint
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ class ResourceType(Enum):
|
|||
dataset = "dataset"
|
||||
scoring_function = "scoring_function"
|
||||
eval_task = "eval_task"
|
||||
tool = "tool"
|
||||
tool_group = "tool_group"
|
||||
|
||||
|
||||
class Resource(BaseModel):
|
||||
|
|
|
|||
|
|
@ -4,13 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Protocol, runtime_checkable
|
||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F403
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
|
||||
|
||||
|
||||
# mapping of metric to value
|
||||
|
|
|
|||
|
|
@ -6,13 +6,12 @@
|
|||
|
||||
from enum import Enum
|
||||
|
||||
from typing import Any, Dict, List, Optional, Protocol
|
||||
from typing import Any, Dict, List, Optional, Protocol, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import Message
|
||||
|
||||
|
||||
|
|
|
|||
7
llama_stack/apis/tools/__init__.py
Normal file
7
llama_stack/apis/tools/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .tools import * # noqa: F401 F403
|
||||
141
llama_stack/apis/tools/tools.py
Normal file
141
llama_stack/apis/tools/tools.py
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from llama_models.llama3.api.datatypes import ToolPromptFormat
|
||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Protocol, runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, URL
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolParameter(BaseModel):
|
||||
name: str
|
||||
parameter_type: str
|
||||
description: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Tool(Resource):
|
||||
type: Literal[ResourceType.tool.value] = ResourceType.tool.value
|
||||
tool_group: str
|
||||
description: str
|
||||
parameters: List[ToolParameter]
|
||||
provider_id: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
default=ToolPromptFormat.json
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolDef(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
parameters: List[ToolParameter]
|
||||
metadata: Dict[str, Any]
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
default=ToolPromptFormat.json
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MCPToolGroupDef(BaseModel):
|
||||
"""
|
||||
A tool group that is defined by in a model context protocol server.
|
||||
Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information.
|
||||
"""
|
||||
|
||||
type: Literal["model_context_protocol"] = "model_context_protocol"
|
||||
endpoint: URL
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class UserDefinedToolGroupDef(BaseModel):
|
||||
type: Literal["user_defined"] = "user_defined"
|
||||
tools: List[ToolDef]
|
||||
|
||||
|
||||
ToolGroupDef = register_schema(
|
||||
Annotated[
|
||||
Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type")
|
||||
],
|
||||
name="ToolGroup",
|
||||
)
|
||||
|
||||
|
||||
class ToolGroup(Resource):
|
||||
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolInvocationResult(BaseModel):
|
||||
content: InterleavedContent
|
||||
error_message: Optional[str] = None
|
||||
error_code: Optional[int] = None
|
||||
|
||||
|
||||
class ToolStore(Protocol):
|
||||
def get_tool(self, tool_name: str) -> Tool: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class ToolGroups(Protocol):
|
||||
@webmethod(route="/toolgroups/register", method="POST")
|
||||
async def register_tool_group(
|
||||
self,
|
||||
tool_group_id: str,
|
||||
tool_group: ToolGroupDef,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Register a tool group"""
|
||||
...
|
||||
|
||||
@webmethod(route="/toolgroups/get", method="GET")
|
||||
async def get_tool_group(
|
||||
self,
|
||||
tool_group_id: str,
|
||||
) -> ToolGroup: ...
|
||||
|
||||
@webmethod(route="/toolgroups/list", method="GET")
|
||||
async def list_tool_groups(self) -> List[ToolGroup]:
|
||||
"""List tool groups with optional provider"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tools/list", method="GET")
|
||||
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
|
||||
"""List tools with optional tool group"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tools/get", method="GET")
|
||||
async def get_tool(self, tool_name: str) -> Tool: ...
|
||||
|
||||
@webmethod(route="/toolgroups/unregister", method="POST")
|
||||
async def unregister_tool_group(self, tool_group_id: str) -> None:
|
||||
"""Unregister a tool group"""
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class ToolRuntime(Protocol):
|
||||
tool_store: ToolStore
|
||||
|
||||
@webmethod(route="/tool-runtime/discover", method="POST")
|
||||
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: ...
|
||||
|
||||
@webmethod(route="/tool-runtime/invoke", method="POST")
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
"""Run a tool with the given arguments"""
|
||||
...
|
||||
Loading…
Add table
Add a link
Reference in a new issue