forked from phoenix-oss/llama-stack-mirror
chore: move all Llama Stack types from llama-models to llama-stack (#1098)
llama-models should have extremely minimal cruft. Its sole purpose should be didactic -- show the simplest implementation of the llama models and document the prompt formats, etc.

This PR is the complement to https://github.com/meta-llama/llama-models/pull/279

## Test Plan

Ensure all `llama` CLI `model` sub-commands work:

```bash
llama model list
llama model download --model-id ...
llama model prompt-format -m ...
```

Ran tests:

```bash
cd tests/client-sdk
LLAMA_STACK_CONFIG=fireworks pytest -s -v inference/
LLAMA_STACK_CONFIG=fireworks pytest -s -v vector_io/
LLAMA_STACK_CONFIG=fireworks pytest -s -v agents/
```

Created a fresh venv with `uv venv && source .venv/bin/activate`, then ran `llama stack build --template fireworks --image-type venv` followed by `llama stack run together --image-type venv` <-- the server runs.

Also checked that the OpenAPI generator can run and that there is no change in the generated files as a result:

```bash
cd docs/openapi_generator
sh run_openapi_generator.sh
```
This commit is contained in:
parent c0ee512980 · commit 314ee09ae3

138 changed files with 8491 additions and 465 deletions
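Most of the diff below is a mechanical import migration: schema helpers (`json_schema_type`, `register_schema`, `webmethod`) move from `llama_models.schema_utils` to `llama_stack.schema_utils`, and the core Llama datatypes and SKU list move from `llama_models.*` to `llama_stack.models.llama.*`. A minimal before/after sketch of the pattern (the `ExampleParams` model is hypothetical, shown only to illustrate the decorator usage; it is not part of the diff):

```python
from pydantic import BaseModel

# Old locations (removed by this PR):
#   from llama_models.schema_utils import json_schema_type, webmethod
#   from llama_models.llama3.api.datatypes import SamplingParams
# New locations, owned by llama-stack itself:
from llama_stack.models.llama.datatypes import SamplingParams
from llama_stack.schema_utils import json_schema_type


@json_schema_type
class ExampleParams(BaseModel):  # hypothetical model, not part of the diff
    sampling_params: SamplingParams = SamplingParams()
```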
@@ -30,6 +30,7 @@ repos:
     rev: v0.9.4
     hooks:
       - id: ruff
+        exclude: ^llama_stack/strong_typing/.*$
       - id: ruff-format

   - repo: https://github.com/adamchainz/blacken-docs
@@ -43,7 +44,13 @@ repos:
     rev: 0.5.26
     hooks:
       - id: uv-export
-        args: ["--frozen", "--no-hashes", "--no-emit-project"]
+        args: [
+          "--frozen",
+          "--no-hashes",
+          "--no-emit-project",
+          "--output-file=requirements.txt"
+        ]
+        files: ^pyproject\.toml$
       - id: uv-sync

   # - repo: https://github.com/pre-commit/mirrors-mypy
@@ -16,18 +16,6 @@ from pathlib import Path
 import fire
 import ruamel.yaml as yaml

-from llama_models import schema_utils
-
-# We do some monkey-patching to ensure our definitions only use the minimal
-# (json_schema_type, webmethod) definitions from the llama_models package. For
-# generation though, we need the full definitions and implementations from the
-# (json-strong-typing) package.
-
-from .strong_typing.schema import json_schema_type, register_schema
-
-schema_utils.json_schema_type = json_schema_type
-schema_utils.register_schema = register_schema
-
 from llama_stack.apis.version import LLAMA_STACK_API_VERSION  # noqa: E402
 from llama_stack.distribution.stack import LlamaStack  # noqa: E402
@@ -10,9 +10,9 @@ import typing
 from dataclasses import make_dataclass
 from typing import Any, Dict, Set, Union

-from ..strong_typing.core import JsonType
-from ..strong_typing.docstring import Docstring, parse_type
-from ..strong_typing.inspection import (
+from llama_stack.strong_typing.core import JsonType
+from llama_stack.strong_typing.docstring import Docstring, parse_type
+from llama_stack.strong_typing.inspection import (
     is_generic_list,
     is_type_optional,
     is_type_union,
@@ -20,15 +20,15 @@ from ..strong_typing.inspection import (
     unwrap_optional_type,
     unwrap_union_types,
 )
-from ..strong_typing.name import python_type_to_name
-from ..strong_typing.schema import (
+from llama_stack.strong_typing.name import python_type_to_name
+from llama_stack.strong_typing.schema import (
     get_schema_identifier,
     JsonSchemaGenerator,
     register_schema,
     Schema,
     SchemaOptions,
 )
-from ..strong_typing.serialization import json_dump_string, object_to_json
+from llama_stack.strong_typing.serialization import json_dump_string, object_to_json

 from .operations import (
     EndpointOperation,
@@ -15,7 +15,7 @@ from llama_stack.apis.version import LLAMA_STACK_API_VERSION

 from termcolor import colored

-from ..strong_typing.inspection import get_signature
+from llama_stack.strong_typing.inspection import get_signature


 def split_prefix(
@@ -9,7 +9,7 @@ import enum
 from dataclasses import dataclass
 from typing import Any, ClassVar, Dict, List, Optional, Union

-from ..strong_typing.schema import JsonType, Schema, StrictJsonType
+from llama_stack.strong_typing.schema import JsonType, Schema, StrictJsonType

 URL = str

@@ -9,7 +9,7 @@ import typing
 from pathlib import Path
 from typing import TextIO

-from ..strong_typing.schema import object_to_json, StrictJsonType
+from llama_stack.strong_typing.schema import object_to_json, StrictJsonType

 from .generator import Generator
 from .options import Options
@@ -19,7 +19,6 @@ from typing import (
     runtime_checkable,
 )

-from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, ConfigDict, Field

 from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
@@ -38,6 +37,7 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.safety import SafetyViolation
 from llama_stack.apis.tools import ToolDef
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+from llama_stack.schema_utils import json_schema_type, register_schema, webmethod


 class Attachment(BaseModel):
@@ -1,206 +0,0 @@ (entire file deleted)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from llama_models.llama3.api.datatypes import ToolPromptFormat
from llama_models.llama3.api.tool_utils import ToolUtils
from termcolor import cprint

from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
from llama_stack.apis.common.content_types import ToolCallParseStatus
from llama_stack.apis.inference import ToolResponseMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
    interleaved_content_as_str,
)


class LogEvent:
    def __init__(
        self,
        role: Optional[str] = None,
        content: str = "",
        end: str = "\n",
        color="white",
    ):
        self.role = role
        self.content = content
        self.color = color
        self.end = "\n" if end is None else end

    def __str__(self):
        if self.role is not None:
            return f"{self.role}> {self.content}"
        else:
            return f"{self.content}"

    def print(self, flush=True):
        cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush)


EventType = AgentTurnResponseEventType


class EventLogger:
    async def log(
        self,
        event_generator,
        stream=True,
        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
    ):
        previous_event_type = None
        previous_step_type = None

        async for chunk in event_generator:
            if not hasattr(chunk, "event"):
                # Need to check for custom tool first
                # since it does not produce event but instead
                # a Message
                if isinstance(chunk, ToolResponseMessage):
                    yield (
                        chunk,
                        LogEvent(role="CustomTool", content=chunk.content, color="grey"),
                    )
                continue

            event = chunk.event
            event_type = event.payload.event_type
            if event_type in {
                EventType.turn_start.value,
                EventType.turn_complete.value,
            }:
                # Currently not logging any turn related info
                yield event, None
                continue

            step_type = event.payload.step_type
            # handle safety
            if step_type == StepType.shield_call and event_type == EventType.step_complete.value:
                violation = event.payload.step_details.violation
                if not violation:
                    yield (
                        event,
                        LogEvent(role=step_type, content="No Violation", color="magenta"),
                    )
                else:
                    yield (
                        event,
                        LogEvent(
                            role=step_type,
                            content=f"{violation.metadata} {violation.user_message}",
                            color="red",
                        ),
                    )

            # handle inference
            if step_type == StepType.inference:
                if stream:
                    if event_type == EventType.step_start.value:
                        # TODO: Currently this event is never received
                        yield (
                            event,
                            LogEvent(role=step_type, content="", end="", color="yellow"),
                        )
                    elif event_type == EventType.step_progress.value:
                        # HACK: if previous was not step/event was not inference's step_progress
                        # this is the first time we are getting model inference response
                        # aka equivalent to step_start for inference. Hence,
                        # start with "Model>".
                        if (
                            previous_event_type != EventType.step_progress.value
                            and previous_step_type != StepType.inference
                        ):
                            yield (
                                event,
                                LogEvent(role=step_type, content="", end="", color="yellow"),
                            )

                        delta = event.payload.delta
                        if delta.type == "tool_call":
                            if delta.parse_status == ToolCallParseStatus.succeeded:
                                yield (
                                    event,
                                    LogEvent(
                                        role=None,
                                        content=delta.tool_call,
                                        end="",
                                        color="cyan",
                                    ),
                                )
                        else:
                            yield (
                                event,
                                LogEvent(
                                    role=None,
                                    content=delta.text,
                                    end="",
                                    color="yellow",
                                ),
                            )
                    else:
                        # step_complete
                        yield event, LogEvent(role=None, content="")

                else:
                    # Not streaming
                    if event_type == EventType.step_complete.value:
                        response = event.payload.step_details.model_response
                        if response.tool_calls:
                            content = ToolUtils.encode_tool_call(response.tool_calls[0], tool_prompt_format)
                        else:
                            content = response.content
                        yield (
                            event,
                            LogEvent(
                                role=step_type,
                                content=content,
                                color="yellow",
                            ),
                        )

            # handle tool_execution
            if (
                step_type == StepType.tool_execution
                and
                # Only print tool calls and responses at the step_complete event
                event_type == EventType.step_complete.value
            ):
                details = event.payload.step_details
                for t in details.tool_calls:
                    yield (
                        event,
                        LogEvent(
                            role=step_type,
                            content=f"Tool:{t.tool_name} Args:{t.arguments}",
                            color="green",
                        ),
                    )
                for r in details.tool_responses:
                    yield (
                        event,
                        LogEvent(
                            role=step_type,
                            content=f"Tool:{r.tool_name} Response:{r.content}",
                            color="green",
                        ),
                    )

            if step_type == StepType.memory_retrieval and event_type == EventType.step_complete.value:
                details = event.payload.step_details
                inserted_context = interleaved_content_as_str(details.inserted_context)
                content = f"fetched {len(inserted_context)} bytes from {details.vector_db_ids}"

                yield (
                    event,
                    LogEvent(
                        role=step_type,
                        content=content,
                        color="cyan",
                    ),
                )

            previous_event_type = event_type
            previous_step_type = step_type
@@ -6,7 +6,6 @@

 from typing import List, Optional, Protocol, runtime_checkable

-from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel

 from llama_stack.apis.inference import (
@@ -21,6 +20,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.schema_utils import json_schema_type, webmethod


 @json_schema_type
@@ -5,10 +5,10 @@
 # the root directory of this source tree.
 from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable

-from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field

 from llama_stack.apis.resource import Resource, ResourceType
+from llama_stack.schema_utils import json_schema_type, webmethod


 class CommonBenchmarkFields(BaseModel):
@@ -7,10 +7,11 @@
 from enum import Enum
 from typing import Annotated, List, Literal, Optional, Union

-from llama_models.llama3.api.datatypes import ToolCall
-from llama_models.schema_utils import json_schema_type, register_schema
 from pydantic import BaseModel, Field, model_validator

+from llama_stack.models.llama.datatypes import ToolCall
+from llama_stack.schema_utils import json_schema_type, register_schema
+

 @json_schema_type
 class URL(BaseModel):
@@ -7,10 +7,10 @@
 from enum import Enum
 from typing import Any, Dict, Optional

-from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel

 from llama_stack.apis.common.content_types import URL
+from llama_stack.schema_utils import json_schema_type


 @json_schema_type
@@ -5,9 +5,10 @@
 # the root directory of this source tree.
 from enum import Enum

-from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel

+from llama_stack.schema_utils import json_schema_type
+

 @json_schema_type
 class Job(BaseModel):
@@ -7,9 +7,10 @@
 from datetime import datetime
 from typing import Optional

-from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel

+from llama_stack.schema_utils import json_schema_type
+

 @json_schema_type
 class PostTrainingMetric(BaseModel):
@@ -6,10 +6,11 @@

 from typing import Literal, Union

-from llama_models.schema_utils import json_schema_type, register_schema
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated

+from llama_stack.schema_utils import json_schema_type, register_schema
+

 @json_schema_type
 class StringType(BaseModel):
@@ -6,10 +6,10 @@

 from typing import Any, Dict, List, Optional, Protocol, runtime_checkable

-from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel

 from llama_stack.apis.datasets import Dataset
+from llama_stack.schema_utils import json_schema_type, webmethod


 @json_schema_type
@@ -6,12 +6,12 @@

 from typing import Any, Dict, List, Literal, Optional, Protocol

-from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field

 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.common.type_system import ParamType
 from llama_stack.apis.resource import Resource, ResourceType
+from llama_stack.schema_utils import json_schema_type, webmethod


 class CommonDatasetFields(BaseModel):
@@ -6,7 +6,7 @@

 from enum import Enum

-from llama_models.schema_utils import json_schema_type
+from llama_stack.schema_utils import json_schema_type


 @json_schema_type
@@ -6,7 +6,6 @@

 from typing import Any, Dict, List, Literal, Optional, Protocol, Union

-from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated

@@ -15,6 +14,7 @@ from llama_stack.apis.common.job_types import Job, JobStatus
 from llama_stack.apis.inference import SamplingParams, SystemMessage
 from llama_stack.apis.scoring import ScoringResult
 from llama_stack.apis.scoring_functions import ScoringFnParams
+from llama_stack.schema_utils import json_schema_type, register_schema, webmethod


 @json_schema_type
@@ -17,7 +17,13 @@ from typing import (
     runtime_checkable,
 )

-from llama_models.llama3.api.datatypes import (
+from pydantic import BaseModel, Field, field_validator
+from typing_extensions import Annotated
+
+from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
+from llama_stack.apis.models import Model
+from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
+from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     SamplingParams,
     StopReason,
@@ -25,14 +31,8 @@ from llama_models.llama3.api.datatypes import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_models.schema_utils import json_schema_type, register_schema, webmethod
-from pydantic import BaseModel, Field, field_validator
-from typing_extensions import Annotated
-
-from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
-from llama_stack.apis.models import Model
-from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+from llama_stack.schema_utils import json_schema_type, register_schema, webmethod


 class LogProbConfig(BaseModel):
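After this hunk, the inference API sources `BuiltinTool`, `SamplingParams`, `StopReason`, `ToolDefinition`, and `ToolPromptFormat` from the new `llama_stack.models.llama.datatypes` module instead of `llama-models`. A minimal sketch, assuming the type definitions added later in this diff (llama_stack/models/llama/datatypes.py), of how a caller might build sampling parameters against the relocated types; the values are illustrative:

```python
from llama_stack.models.llama.datatypes import (
    SamplingParams,
    TopPSamplingStrategy,
)

# `strategy` is a tagged union discriminated on `type`, so a dict with
# {"type": "top_p", ...} would validate to TopPSamplingStrategy as well.
params = SamplingParams(
    strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.9),
    max_tokens=512,
)
print(params.model_dump())
```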
@@ -6,9 +6,10 @@

 from typing import List, Protocol, runtime_checkable

-from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel

+from llama_stack.schema_utils import json_schema_type, webmethod
+

 @json_schema_type
 class ProviderInfo(BaseModel):
@@ -7,11 +7,11 @@
 from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable

-from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, ConfigDict, Field

 from llama_stack.apis.resource import Resource, ResourceType
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+from llama_stack.schema_utils import json_schema_type, webmethod


 class CommonModelFields(BaseModel):
@@ -8,13 +8,13 @@ from datetime import datetime
 from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Protocol, Union

-from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated

 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.common.training_types import Checkpoint
+from llama_stack.schema_utils import json_schema_type, register_schema, webmethod


 @json_schema_type
@@ -7,12 +7,12 @@
 from enum import Enum
 from typing import Any, Dict, List, Optional, Protocol, runtime_checkable

-from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field

 from llama_stack.apis.inference import Message
 from llama_stack.apis.shields import Shield
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+from llama_stack.schema_utils import json_schema_type, webmethod


 @json_schema_type
@@ -6,10 +6,10 @@

 from typing import Any, Dict, List, Optional, Protocol, runtime_checkable

-from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel

 from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
+from llama_stack.schema_utils import json_schema_type, webmethod

 # mapping of metric to value
 ScoringResultRow = Dict[str, Any]
@@ -16,12 +16,12 @@ from typing import (
     runtime_checkable,
 )

-from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated

 from llama_stack.apis.common.type_system import ParamType
 from llama_stack.apis.resource import Resource, ResourceType
+from llama_stack.schema_utils import json_schema_type, register_schema, webmethod


 # Perhaps more structure can be imposed on these functions. Maybe they could be associated
@@ -6,11 +6,11 @@

 from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable

-from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel

 from llama_stack.apis.resource import Resource, ResourceType
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+from llama_stack.schema_utils import json_schema_type, webmethod


 class CommonShieldFields(BaseModel):
@@ -7,10 +7,10 @@
 from enum import Enum
 from typing import Any, Dict, List, Optional, Protocol, Union

-from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel

 from llama_stack.apis.inference import Message
+from llama_stack.schema_utils import json_schema_type, webmethod


 class FilteringFunction(Enum):
@@ -17,11 +17,12 @@ from typing import (
     runtime_checkable,
 )

-from llama_models.llama3.api.datatypes import Primitive
-from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated

+from llama_stack.models.llama.datatypes import Primitive
+from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
+
 # Add this constant near the top of the file, after the imports
 DEFAULT_TTL_DAYS = 7

@@ -7,12 +7,12 @@
 from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Union

-from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated, Protocol, runtime_checkable

 from llama_stack.apis.common.content_types import URL, InterleavedContent
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+from llama_stack.schema_utils import json_schema_type, register_schema, webmethod


 @json_schema_type
@@ -7,13 +7,13 @@
 from enum import Enum
 from typing import Any, Dict, List, Literal, Optional

-from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Protocol, runtime_checkable

 from llama_stack.apis.common.content_types import URL, InterleavedContent
 from llama_stack.apis.resource import Resource, ResourceType
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+from llama_stack.schema_utils import json_schema_type, webmethod

 from .rag_tool import RAGToolRuntime

@@ -6,11 +6,11 @@

 from typing import List, Literal, Optional, Protocol, runtime_checkable

-from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel

 from llama_stack.apis.resource import Resource, ResourceType
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+from llama_stack.schema_utils import json_schema_type, webmethod


 @json_schema_type
@@ -10,12 +10,12 @@
 # the root directory of this source tree.
 from typing import Any, Dict, List, Optional, Protocol, runtime_checkable

-from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field

 from llama_stack.apis.inference import InterleavedContent
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+from llama_stack.schema_utils import json_schema_type, webmethod


 class Chunk(BaseModel):
@@ -16,8 +16,6 @@ from pathlib import Path
 from typing import Dict, List, Optional

 import httpx
-from llama_models.datatypes import Model
-from llama_models.sku_list import LlamaDownloadInfo
 from pydantic import BaseModel, ConfigDict
 from rich.console import Console
 from rich.progress import (
@@ -31,6 +29,8 @@ from rich.progress import (
 from termcolor import cprint

 from llama_stack.cli.subcommand import Subcommand
+from llama_stack.models.llama.datatypes import Model
+from llama_stack.models.llama.sku_list import LlamaDownloadInfo


 class Download(Subcommand):
@@ -454,7 +454,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
     # Handle comma-separated model IDs
     model_ids = [model_id.strip() for model_id in args.model_id.split(",")]

-    from llama_models.sku_list import llama_meta_net_info, resolve_model
+    from llama_stack.models.llama.sku_list import llama_meta_net_info, resolve_model

     from .model.safety_models import (
         prompt_guard_download_info,
@@ -7,11 +7,11 @@
 import argparse
 import json

-from llama_models.sku_list import resolve_model
 from termcolor import colored

 from llama_stack.cli.subcommand import Subcommand
 from llama_stack.cli.table import print_table
+from llama_stack.models.llama.sku_list import resolve_model


 class ModelDescribe(Subcommand):
@@ -6,10 +6,9 @@

 import argparse

-from llama_models.sku_list import all_registered_models
-
 from llama_stack.cli.subcommand import Subcommand
 from llama_stack.cli.table import print_table
+from llama_stack.models.llama.sku_list import all_registered_models


 class ModelList(Subcommand):
@@ -8,9 +8,8 @@ import argparse
 import textwrap
 from io import StringIO

-from llama_models.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
-
 from llama_stack.cli.subcommand import Subcommand
+from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family


 class ModelPromptFormat(Subcommand):
@@ -6,11 +6,11 @@

 from typing import Any, Dict, Optional

-from llama_models.datatypes import CheckpointQuantizationFormat
-from llama_models.llama3.api.datatypes import SamplingParams
-from llama_models.sku_list import LlamaDownloadInfo
 from pydantic import BaseModel, ConfigDict, Field

+from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat, SamplingParams
+from llama_stack.models.llama.sku_list import LlamaDownloadInfo
+

 class PromptGuardModel(BaseModel):
     """Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""
@@ -186,33 +186,3 @@ def extract_async_iterator_type(type_hint):
             inner_args = get_args(arg)
             return inner_args[0]
     return None
-
-
-async def example(model: str = None):
-    from llama_stack.apis.inference import Inference, UserMessage  # noqa: F403
-    from llama_stack.apis.inference.event_logger import EventLogger
-
-    client_class = create_api_client_class(Inference)
-    client = client_class("http://localhost:5003")
-
-    if not model:
-        model = "Llama3.2-3B-Instruct"
-
-    message = UserMessage(content="hello world, write me a 2 sentence poem about the moon")
-    cprint(f"User>{message.content}", "green")
-
-    stream = True
-    iterator = await client.chat_completion(
-        model=model,
-        messages=[message],
-        stream=stream,
-    )
-
-    async for log in EventLogger().log(iterator):
-        log.print()
-
-
-if __name__ == "__main__":
-    import asyncio
-
-    asyncio.run(example())
llama_stack/models/llama/datatypes.py (new file, 277 lines)
@@ -0,0 +1,277 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

from enum import Enum
from typing import Any, Dict, Literal, Optional, Union

# import all for backwards compatibility
from llama_models.datatypes import *  # noqa: F403
from pydantic import BaseModel, ConfigDict, Field, field_validator
from typing_extensions import Annotated

from llama_stack.schema_utils import json_schema_type, register_schema

register_schema(ToolCall)


@json_schema_type
class ToolParamDefinition(BaseModel):
    param_type: str
    description: Optional[str] = None
    required: Optional[bool] = True
    default: Optional[Any] = None


@json_schema_type
class ToolDefinition(BaseModel):
    tool_name: Union[BuiltinTool, str]
    description: Optional[str] = None
    parameters: Optional[Dict[str, ToolParamDefinition]] = None

    @field_validator("tool_name", mode="before")
    @classmethod
    def validate_field(cls, v):
        if isinstance(v, str):
            try:
                return BuiltinTool(v)
            except ValueError:
                return v
        return v


@json_schema_type
class GreedySamplingStrategy(BaseModel):
    type: Literal["greedy"] = "greedy"


@json_schema_type
class TopPSamplingStrategy(BaseModel):
    type: Literal["top_p"] = "top_p"
    temperature: Optional[float] = Field(..., gt=0.0)
    top_p: Optional[float] = 0.95


@json_schema_type
class TopKSamplingStrategy(BaseModel):
    type: Literal["top_k"] = "top_k"
    top_k: int = Field(..., ge=1)


SamplingStrategy = register_schema(
    Annotated[
        Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
        Field(discriminator="type"),
    ],
    name="SamplingStrategy",
)


@json_schema_type
class SamplingParams(BaseModel):
    strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)

    max_tokens: Optional[int] = 0
    repetition_penalty: Optional[float] = 1.0


class CheckpointQuantizationFormat(Enum):
    # default format
    bf16 = "bf16"

    # used for enabling fp8_rowwise inference, some weights are bf16
    fp8_mixed = "fp8-mixed"

    int8 = "int8"

    int4 = "int4"


class ModelFamily(Enum):
    llama2 = "llama2"
    llama3 = "llama3"
    llama3_1 = "llama3_1"
    llama3_2 = "llama3_2"
    llama3_3 = "llama3_3"
    safety = "safety"


class CoreModelId(Enum):
    """Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""

    # Llama 2 family
    llama2_7b = "Llama-2-7b"
    llama2_13b = "Llama-2-13b"
    llama2_70b = "Llama-2-70b"
    llama2_7b_chat = "Llama-2-7b-chat"
    llama2_13b_chat = "Llama-2-13b-chat"
    llama2_70b_chat = "Llama-2-70b-chat"

    # Llama 3 family
    llama3_8b = "Llama-3-8B"
    llama3_70b = "Llama-3-70B"
    llama3_8b_instruct = "Llama-3-8B-Instruct"
    llama3_70b_instruct = "Llama-3-70B-Instruct"

    # Llama 3.1 family
    llama3_1_8b = "Llama3.1-8B"
    llama3_1_70b = "Llama3.1-70B"
    llama3_1_405b = "Llama3.1-405B"
    llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
    llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
    llama3_1_405b_instruct = "Llama3.1-405B-Instruct"

    # Llama 3.2 family
    llama3_2_1b = "Llama3.2-1B"
    llama3_2_3b = "Llama3.2-3B"
    llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
    llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
    llama3_2_11b_vision = "Llama3.2-11B-Vision"
    llama3_2_90b_vision = "Llama3.2-90B-Vision"
    llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
    llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"

    # Llama 3.3 family
    llama3_3_70b_instruct = "Llama3.3-70B-Instruct"

    # Safety models
    llama_guard_3_8b = "Llama-Guard-3-8B"
    llama_guard_2_8b = "Llama-Guard-2-8B"
    llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
    llama_guard_3_1b = "Llama-Guard-3-1B"


def is_multimodal(model_id) -> bool:
    if model_id in [
        CoreModelId.llama3_2_11b_vision,
        CoreModelId.llama3_2_90b_vision,
        CoreModelId.llama3_2_11b_vision_instruct,
        CoreModelId.llama3_2_90b_vision_instruct,
    ]:
        return True
    else:
        return False


def model_family(model_id) -> ModelFamily:
    if model_id in [
        CoreModelId.llama2_7b,
        CoreModelId.llama2_13b,
        CoreModelId.llama2_70b,
        CoreModelId.llama2_7b_chat,
        CoreModelId.llama2_13b_chat,
        CoreModelId.llama2_70b_chat,
    ]:
        return ModelFamily.llama2
    elif model_id in [
        CoreModelId.llama3_8b,
        CoreModelId.llama3_70b,
        CoreModelId.llama3_8b_instruct,
        CoreModelId.llama3_70b_instruct,
    ]:
        return ModelFamily.llama3
    elif model_id in [
        CoreModelId.llama3_1_8b,
        CoreModelId.llama3_1_70b,
        CoreModelId.llama3_1_405b,
        CoreModelId.llama3_1_8b_instruct,
        CoreModelId.llama3_1_70b_instruct,
        CoreModelId.llama3_1_405b_instruct,
    ]:
        return ModelFamily.llama3_1
    elif model_id in [
        CoreModelId.llama3_2_1b,
        CoreModelId.llama3_2_3b,
        CoreModelId.llama3_2_1b_instruct,
        CoreModelId.llama3_2_3b_instruct,
        CoreModelId.llama3_2_11b_vision,
        CoreModelId.llama3_2_90b_vision,
        CoreModelId.llama3_2_11b_vision_instruct,
        CoreModelId.llama3_2_90b_vision_instruct,
    ]:
        return ModelFamily.llama3_2
    elif model_id in [
        CoreModelId.llama3_3_70b_instruct,
    ]:
        return ModelFamily.llama3_3
    elif model_id in [
        CoreModelId.llama_guard_3_8b,
        CoreModelId.llama_guard_2_8b,
        CoreModelId.llama_guard_3_11b_vision,
        CoreModelId.llama_guard_3_1b,
    ]:
        return ModelFamily.safety
    else:
        raise ValueError(f"Unknown model family for {model_id}")


class Model(BaseModel):
    core_model_id: CoreModelId
    description: str
    huggingface_repo: Optional[str] = None
    recommended_sampling_params: Optional[SamplingParams] = None
    arch_args: Dict[str, Any]
    variant: str = ""

    quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
    pth_file_count: int
    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)

    # silence pydantic until we remove the `model_` fields
    model_config = ConfigDict(protected_namespaces=())

    @property
    def model_family(self) -> ModelFamily:
        return model_family(self.core_model_id)

    # The SKU is uniquely identified by (model_id, variant) combo
    def descriptor(self, shorten_default_variant: bool = True) -> str:
        if not self.variant:
            return self.core_model_id.value
        return f"{self.core_model_id.value}:{self.variant}"

    @property
    def is_instruct_model(self) -> bool:
        return "instruct" in self.id.name

    # Featured models are shown in the non-exhaustive model list
    @property
    def is_featured(self) -> bool:
        return self.model_family in [
            ModelFamily.llama3_1,
            ModelFamily.llama3_2,
            ModelFamily.llama3_3,
            ModelFamily.safety,
        ]

    @property
    def max_seq_length(self) -> int:
        if self.model_family == ModelFamily.llama2:
            return 4096
        elif self.core_model_id == CoreModelId.llama_guard_2_8b:
            return 4096
        elif self.model_family == ModelFamily.llama3:
            return 8192
        elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
            return 131072
        elif self.model_family == ModelFamily.llama3_2:
            if self.quantization_format == CheckpointQuantizationFormat.int4:
                return 8192
            return 131072
        elif self.core_model_id in [
            CoreModelId.llama_guard_3_8b,
            CoreModelId.llama_guard_3_11b_vision,
            CoreModelId.llama_guard_3_1b,
        ]:
            return 131072
        else:
            raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")
BIN llama_stack/models/llama/llama3/dog.jpg (new binary file, 39 KiB; not shown)
llama_stack/models/llama/llama3/interface.py (new file, 257 lines)
@@ -0,0 +1,257 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

from pathlib import Path
from typing import List, Optional

from llama_models.datatypes import (
    BuiltinTool,
    RawMessage,
    StopReason,
    ToolCall,
    ToolPromptFormat,
)
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from termcolor import colored

from llama_stack.models.llama.datatypes import ToolDefinition

from . import template_data
from .prompt_templates import (
    BuiltinToolGenerator,
    FunctionTagCustomToolGenerator,
    JsonCustomToolGenerator,
    SystemDefaultGenerator,
    ToolResponseGenerator,
)

THIS_DIR = Path(__file__).parent


class Template:
    def __init__(
        self,
        role,
        template_name,
        data_provider=None,
        notes=None,
    ):
        self.role = role
        self.template_name = template_name
        self.data_provider = data_provider or ""
        self._notes = notes or ""

    @property
    def notes(self):
        default = "↵ represents newline"
        notes = default
        if self._notes:
            notes += "\n"
            notes += self._notes
        return notes


TEMPLATES = [
    Template(
        "user",
        "user-default",
        "user_default",
    ),
    Template(
        "user",
        "user-images",
        "user_images",
    ),
    Template("user", "user-interleaved-images", "user_interleaved_images"),
    Template(
        "assistant",
        "assistant-builtin-tool-call",
        "assistant_builtin_tool_call",
        "Notice <|python_tag|>",
    ),
    Template(
        "assistant",
        "assistant-custom-tool-call",
        "assistant_custom_tool_call",
        "Notice <function=...> format",
    ),
    Template(
        "assistant",
        "assistant-default",
        "assistant_default",
    ),
    Template(
        "system",
        "system-builtin-and-custom-tools",
        "system_message_builtin_and_custom_tools",
    ),
    Template(
        "system",
        "system-builtin-tools-only",
        "system_message_builtin_tools_only",
    ),
    Template(
        "system",
        "system-custom-tools-only",
        "system_message_custom_tools_only",
    ),
    Template(
        "system",
        "system-default",
        "system_default",
    ),
    Template(
        "tool",
        "tool-success",
        "tool_success",
        "Note ipython header and [stdout]",
    ),
    Template(
        "tool",
        "tool-failure",
        "tool_failure",
        "Note ipython header and [stderr]",
    ),
]


class LLama31Interface:
    def __init__(self, tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json):
        self.tokenizer = Tokenizer.get_instance()
        self.formatter = ChatFormat(self.tokenizer)
        self.tool_prompt_format = tool_prompt_format

    def get_tokens(self, messages: List[RawMessage]) -> List[int]:
        model_input = self.formatter.encode_dialog_prompt(
            messages,
            self.tool_prompt_format,
        )
        return model_input.tokens

    def tool_response_messages(self, *args, **kwargs):
        template = ToolResponseGenerator().gen(*args, **kwargs)
        return [
            RawMessage(
                role="tool",
                content=template.render(),
            )
        ]

    def system_messages(
        self,
        builtin_tools: List[BuiltinTool],
        custom_tools: List[ToolDefinition],
        instruction: Optional[str] = None,
    ) -> List[RawMessage]:
        messages = []

        default_gen = SystemDefaultGenerator()
        default_template = default_gen.gen()

        sys_content = ""

        tool_template = None
        if builtin_tools or custom_tools:
            tool_gen = BuiltinToolGenerator()
            tool_template = tool_gen.gen(builtin_tools + custom_tools)

            sys_content += tool_template.render()
            sys_content += "\n"

        sys_content += default_template.render()

        if instruction:
            sys_content += "\n\n"
            sys_content += instruction

        sys_content += "\n"
        messages.append(RawMessage(role="system", content=sys_content))

        if custom_tools:
            if self.tool_prompt_format == ToolPromptFormat.json:
                tool_gen = JsonCustomToolGenerator()
            elif self.tool_prompt_format == ToolPromptFormat.function_tag:
                tool_gen = FunctionTagCustomToolGenerator()
            else:
                raise ValueError(f"Non supported ToolPromptFormat {self.tool_prompt_format}")

            custom_template = tool_gen.gen(custom_tools)
            messages.append(RawMessage(role="user", content=custom_template.render()))

        return messages

    def assistant_response_messages(
        self,
        content: str,
        stop_reason: StopReason,
        tool_call: Optional[ToolCall] = None,
    ) -> List[RawMessage]:
        tool_calls = []
        if tool_call:
            tool_calls.append(tool_call)
        return [
            RawMessage(
                role="assistant",
                content=content,
                tool_calls=tool_calls,
                stop_reason=stop_reason,
            )
        ]

    def user_message(self, content: str) -> List[RawMessage]:
        return [RawMessage(role="user", content=content)]

    def display_message_as_tokens(self, message: RawMessage) -> None:
        """Util to print tokenized string to shell"""
        tokens = self.formatter.encode_message(message, self.tool_prompt_format)
        on_colors = [
            "on_red",
            "on_green",
            "on_yellow",
            "on_blue",
            "on_magenta",
            "on_cyan",
        ]
        for i, t in enumerate(tokens):
            on_col = on_colors[i % len(on_colors)]
            print(colored(self.tokenizer.decode([t]), "white", on_col), end="")
        print("\n", end="")


def list_jinja_templates() -> List[Template]:
    return TEMPLATES


def render_jinja_template(name: str, tool_prompt_format: ToolPromptFormat):
    by_name = {t.template_name: t for t in TEMPLATES}
    if name not in by_name:
        raise ValueError(f"No template found for `{name}`")

    template = by_name[name]
    interface = LLama31Interface(tool_prompt_format)

    data_func = getattr(template_data, template.data_provider)
    if template.role == "system":
        messages = interface.system_messages(**data_func())
    elif template.role == "tool":
        messages = interface.tool_response_messages(**data_func())
    elif template.role == "assistant":
        messages = interface.assistant_response_messages(**data_func())
    elif template.role == "user":
        messages = interface.user_message(**data_func())

    tokens = interface.get_tokens(messages)
    special_tokens = list(interface.tokenizer.special_tokens.values())
    tokens = [(interface.tokenizer.decode([t]), t in special_tokens) for t in tokens]
|
||||||
|
return template, tokens
|
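For orientation, here is a minimal sketch of how these two helpers are driven (assumed usage, inside this module where `render_jinja_template` and `ToolPromptFormat` are already available; the template name is one of the `TEMPLATES` entries above):

```python
# Minimal sketch (assumed usage): render one of the templates listed above
# and print each token, marking special tokens with '*'.
template, tokens = render_jinja_template("assistant-builtin-tool-call", ToolPromptFormat.json)
print(template.notes)
for text, is_special in tokens:
    print(f"{'*' if is_special else ' '} {text!r}")
```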
BIN llama_stack/models/llama/llama3/pasta.jpeg Normal file (binary file not shown; 438 KiB)

22 llama_stack/models/llama/llama3/prompt_templates/__init__.py Normal file
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

from .base import PromptTemplate, PromptTemplateGeneratorBase  # noqa: F401
from .system_prompts import (  # noqa: F401
    BuiltinToolGenerator,
    FunctionTagCustomToolGenerator,
    JsonCustomToolGenerator,
    PythonListCustomToolGenerator,
    SystemDefaultGenerator,
)
from .tool_response import ToolResponseGenerator  # noqa: F401
39 llama_stack/models/llama/llama3/prompt_templates/base.py Normal file
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

from dataclasses import dataclass
from typing import Any, Dict, List

from jinja2 import Template


@dataclass
class PromptTemplate:
    template: str
    data: Dict[str, Any]

    def render(self):
        template = Template(self.template)
        return template.render(self.data)


class PromptTemplateGeneratorBase:
    """
    Base class for prompt template generators.
    """

    def gen(self, *args, **kwargs) -> PromptTemplate:
        raise NotImplementedError()

    def data_examples(self) -> List[Any]:
        raise NotImplementedError()
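A `PromptTemplate` is just a Jinja2 template string plus the data it should be rendered with. A minimal sketch of a generator subclass (the class and strings below are illustrative only, not part of this diff):

```python
# Illustrative only: a generator subclass returns a PromptTemplate whose
# render() call substitutes the Jinja2 variables.
class GreetingGenerator(PromptTemplateGeneratorBase):
    def gen(self, name: str) -> PromptTemplate:
        return PromptTemplate("Hello {{ name }}!", {"name": name})

    def data_examples(self):
        return ["world"]


print(GreetingGenerator().gen("world").render())  # Hello world!
```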
@@ -0,0 +1,311 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
import textwrap
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
from llama_models.datatypes import (
|
||||||
|
BuiltinTool,
|
||||||
|
)
|
||||||
|
|
||||||
|
from llama_stack.models.llama.datatypes import (
|
||||||
|
ToolDefinition,
|
||||||
|
ToolParamDefinition,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .base import PromptTemplate, PromptTemplateGeneratorBase
|
||||||
|
|
||||||
|
|
||||||
|
class SystemDefaultGenerator(PromptTemplateGeneratorBase):
|
||||||
|
def gen(self, *args, **kwargs) -> PromptTemplate:
|
||||||
|
template_str = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Cutting Knowledge Date: December 2023
|
||||||
|
Today Date: {{ today }}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return PromptTemplate(
|
||||||
|
template_str.lstrip("\n"),
|
||||||
|
{"today": datetime.now().strftime("%d %B %Y")},
|
||||||
|
)
|
||||||
|
|
||||||
|
def data_examples(self) -> List[Any]:
|
||||||
|
return [None]
|
||||||
|
|
||||||
|
|
||||||
|
class BuiltinToolGenerator(PromptTemplateGeneratorBase):
|
||||||
|
def _tool_breakdown(self, tools: List[ToolDefinition]):
|
||||||
|
builtin_tools, custom_tools = [], []
|
||||||
|
for dfn in tools:
|
||||||
|
if isinstance(dfn.tool_name, BuiltinTool):
|
||||||
|
builtin_tools.append(dfn)
|
||||||
|
else:
|
||||||
|
custom_tools.append(dfn)
|
||||||
|
|
||||||
|
return builtin_tools, custom_tools
|
||||||
|
|
||||||
|
def gen(self, tools: List[ToolDefinition]) -> PromptTemplate:
|
||||||
|
builtin_tools, custom_tools = self._tool_breakdown(tools)
|
||||||
|
template_str = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
{% if builtin_tools or custom_tools -%}
|
||||||
|
Environment: ipython
|
||||||
|
{% endif -%}
|
||||||
|
{% set builtin_tools = builtin_tools | reject('equalto', 'code_interpreter') | list -%}
|
||||||
|
{% if builtin_tools -%}
|
||||||
|
Tools: {{ builtin_tools | join(", ") | trim -}}
|
||||||
|
{% endif %}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return PromptTemplate(
|
||||||
|
template_str.lstrip("\n"),
|
||||||
|
{
|
||||||
|
"builtin_tools": [t.tool_name.value for t in builtin_tools],
|
||||||
|
"custom_tools": custom_tools,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def data_examples(self) -> List[List[ToolDefinition]]:
|
||||||
|
return [
|
||||||
|
# builtin tools
|
||||||
|
[
|
||||||
|
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||||
|
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
||||||
|
ToolDefinition(tool_name=BuiltinTool.wolfram_alpha),
|
||||||
|
],
|
||||||
|
# only code interpreter
|
||||||
|
[
|
||||||
|
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
|
||||||
|
def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
|
||||||
|
template_str = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Answer the user's question by making use of the following functions if needed.
|
||||||
|
If none of the function can be used, please say so.
|
||||||
|
Here is a list of functions in JSON format:
|
||||||
|
{% for t in custom_tools -%}
|
||||||
|
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||||
|
{%- set tname = t.tool_name -%}
|
||||||
|
{%- set tdesc = t.description -%}
|
||||||
|
{%- set tparams = t.parameters -%}
|
||||||
|
{%- set required_params = [] -%}
|
||||||
|
{%- for name, param in tparams.items() if param.required == true -%}
|
||||||
|
{%- set _ = required_params.append(name) -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "{{tname}}",
|
||||||
|
"description": "{{tdesc}}",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": [
|
||||||
|
{%- for name, param in tparams.items() %}
|
||||||
|
{
|
||||||
|
"{{name}}": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "{{param.description}}"
|
||||||
|
}
|
||||||
|
}{% if not loop.last %},{% endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
],
|
||||||
|
"required": {{ required_params | tojson }}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{% endfor %}
|
||||||
|
Return function calls in JSON format.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
return PromptTemplate(
|
||||||
|
template_str.lstrip("\n"),
|
||||||
|
{"custom_tools": [t.model_dump() for t in custom_tools]},
|
||||||
|
)
|
||||||
|
|
||||||
|
def data_examples(self) -> List[List[ToolDefinition]]:
|
||||||
|
return [
|
||||||
|
[
|
||||||
|
ToolDefinition(
|
||||||
|
tool_name="trending_songs",
|
||||||
|
description="Returns the trending songs on a Music site",
|
||||||
|
parameters={
|
||||||
|
"n": ToolParamDefinition(
|
||||||
|
param_type="int",
|
||||||
|
description="The number of songs to return",
|
||||||
|
required=True,
|
||||||
|
),
|
||||||
|
"genre": ToolParamDefinition(
|
||||||
|
param_type="str",
|
||||||
|
description="The genre of the songs to return",
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
|
||||||
|
def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
|
||||||
|
template_str = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
You have access to the following functions:
|
||||||
|
|
||||||
|
{% for t in custom_tools %}
|
||||||
|
{#- manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||||
|
{%- set tname = t.tool_name -%}
|
||||||
|
{%- set tdesc = t.description -%}
|
||||||
|
{%- set modified_params = t.parameters.copy() -%}
|
||||||
|
{%- for key, value in modified_params.items() -%}
|
||||||
|
{%- if 'default' in value -%}
|
||||||
|
{%- set _ = value.pop('default', None) -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- set tparams = modified_params | tojson -%}
|
||||||
|
Use the function '{{ tname }}' to '{{ tdesc }}':
|
||||||
|
{"name": "{{tname}}", "description": "{{tdesc}}", "parameters": {{tparams}}}
|
||||||
|
|
||||||
|
{% endfor -%}
|
||||||
|
Think very carefully before calling functions.
|
||||||
|
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
||||||
|
|
||||||
|
<function=example_function_name>{"example_name": "example_value"}</function>
|
||||||
|
|
||||||
|
Reminder:
|
||||||
|
- If looking for real time information use relevant functions before falling back to brave_search
|
||||||
|
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
||||||
|
- Required parameters MUST be specified
|
||||||
|
- Only call one function at a time
|
||||||
|
- Put the entire function call reply on one line
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return PromptTemplate(
|
||||||
|
template_str.lstrip("\n"),
|
||||||
|
{"custom_tools": [t.model_dump() for t in custom_tools]},
|
||||||
|
)
|
||||||
|
|
||||||
|
def data_examples(self) -> List[List[ToolDefinition]]:
|
||||||
|
return [
|
||||||
|
[
|
||||||
|
ToolDefinition(
|
||||||
|
tool_name="trending_songs",
|
||||||
|
description="Returns the trending songs on a Music site",
|
||||||
|
parameters={
|
||||||
|
"n": ToolParamDefinition(
|
||||||
|
param_type="int",
|
||||||
|
description="The number of songs to return",
|
||||||
|
required=True,
|
||||||
|
),
|
||||||
|
"genre": ToolParamDefinition(
|
||||||
|
param_type="str",
|
||||||
|
description="The genre of the songs to return",
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
||||||
|
DEFAULT_PROMPT = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||||
|
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||||
|
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
||||||
|
also point it out. You should only return the function call in tools call sections.
|
||||||
|
|
||||||
|
{{ function_description }}
|
||||||
|
""".strip("\n")
|
||||||
|
)
|
||||||
|
|
||||||
|
def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate:
|
||||||
|
system_prompt = system_prompt or self.DEFAULT_PROMPT
|
||||||
|
return PromptTemplate(
|
||||||
|
system_prompt,
|
||||||
|
{"function_description": self._gen_function_description(custom_tools)},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
|
||||||
|
template_str = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||||
|
You SHOULD NOT include any other text in the response.
|
||||||
|
|
||||||
|
Here is a list of functions in JSON format that you can invoke.
|
||||||
|
|
||||||
|
[
|
||||||
|
{% for t in tools -%}
|
||||||
|
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||||
|
{%- set tname = t.tool_name -%}
|
||||||
|
{%- set tdesc = t.description -%}
|
||||||
|
{%- set tparams = t.parameters -%}
|
||||||
|
{%- set required_params = [] -%}
|
||||||
|
{%- for name, param in tparams.items() if param.required == true -%}
|
||||||
|
{%- set _ = required_params.append(name) -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{
|
||||||
|
"name": "{{tname}}",
|
||||||
|
"description": "{{tdesc}}",
|
||||||
|
"parameters": {
|
||||||
|
"type": "dict",
|
||||||
|
"required": {{ required_params | tojson }},
|
||||||
|
"properties": {
|
||||||
|
{%- for name, param in tparams.items() %}
|
||||||
|
"{{name}}": {
|
||||||
|
"type": "{{param.param_type}}",
|
||||||
|
"description": "{{param.description}}"{% if param.default %},
|
||||||
|
"default": "{{param.default}}"{% endif %}
|
||||||
|
}{% if not loop.last %},{% endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}{% if not loop.last %},
|
||||||
|
{% endif -%}
|
||||||
|
{%- endfor %}
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return PromptTemplate(
|
||||||
|
template_str.strip("\n"),
|
||||||
|
{"tools": [t.model_dump() for t in custom_tools]},
|
||||||
|
).render()
|
||||||
|
|
||||||
|
def data_examples(self) -> List[List[ToolDefinition]]:
|
||||||
|
return [
|
||||||
|
[
|
||||||
|
ToolDefinition(
|
||||||
|
tool_name="get_weather",
|
||||||
|
description="Get weather info for places",
|
||||||
|
parameters={
|
||||||
|
"city": ToolParamDefinition(
|
||||||
|
param_type="string",
|
||||||
|
description="The name of the city to get the weather for",
|
||||||
|
required=True,
|
||||||
|
),
|
||||||
|
"metric": ToolParamDefinition(
|
||||||
|
param_type="string",
|
||||||
|
description="The metric for weather. Options are: celsius, fahrenheit",
|
||||||
|
required=False,
|
||||||
|
default="celsius",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
]
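A minimal sketch of how the system-prompt generators defined in the hunk above are exercised (assumed usage; the expected output matches the unit test added later in this diff):

```python
# Minimal sketch (assumed usage): render BuiltinToolGenerator against its own
# example data; code_interpreter is implied by "Environment: ipython" and is
# filtered out of the Tools: line.
gen = BuiltinToolGenerator()
tools = gen.data_examples()[0]  # code_interpreter, brave_search, wolfram_alpha
print(gen.gen(tools).render())
# Environment: ipython
# Tools: brave_search, wolfram_alpha
```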
@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

import textwrap
from typing import Optional

from .base import PromptTemplate, PromptTemplateGeneratorBase


class ToolResponseGenerator(PromptTemplateGeneratorBase):
    def gen(
        self,
        status: str,
        stdout: Optional[str] = None,
        stderr: Optional[str] = None,
    ):
        assert status in [
            "success",
            "failure",
        ], f"status must be 'success' or 'failure'; Got: {status}"
        template_str = textwrap.dedent(
            """
            {% if status == "success" %}completed{% else %}failed{% endif %}
            {%- if stdout %}
            [stdout]{{ stdout }}[/stdout]
            {%- endif -%}
            {%- if stderr %}
            [stderr]{{ stderr }}[/stderr]
            {%- endif -%}
            """
        )
        return PromptTemplate(
            template_str.lstrip("\n"),
            {
                "status": status,
                "stdout": stdout,
                "stderr": stderr,
            },
        )

    def data_examples(self):
        return [
            # success
            {
                "status": "success",
                "stdout": '{"results":["something something"]}',
            },
            # failure
            {
                "status": "failure",
                "stderr": "brave_search encounter an error: could not communicate with api.brave.com",
            },
        ]
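And a quick sketch of the tool-response generator above (assumed usage; rendered output shown roughly in the comments):

```python
# Minimal sketch (assumed usage): render a successful and a failed tool response.
gen = ToolResponseGenerator()
print(gen.gen("success", stdout='{"results": []}').render())
# completed
# [stdout]{"results": []}[/stdout]
print(gen.gen("failure", stderr="timeout").render())
# failed
# [stderr]timeout[/stderr]
```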
120 llama_stack/models/llama/llama3/template_data.py Normal file
@@ -0,0 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

from llama_models.datatypes import (
    BuiltinTool,
    StopReason,
    ToolCall,
)

from .prompt_templates import (
    BuiltinToolGenerator,
    JsonCustomToolGenerator,
    ToolResponseGenerator,
)

INSTRUCTION = "You are a helpful assistant."


def system_message_builtin_tools_only():
    return {
        "builtin_tools": BuiltinToolGenerator().data_examples()[0],
        "custom_tools": [],
        "instruction": INSTRUCTION,
    }


def system_message_builtin_code_only():
    return {
        "builtin_tools": BuiltinToolGenerator().data_examples()[1],
        "custom_tools": [],
        "instruction": "",
    }


def system_message_custom_tools_only():
    return {
        "builtin_tools": [],
        "custom_tools": JsonCustomToolGenerator().data_examples()[0],
        "instruction": INSTRUCTION,
    }


def system_message_builtin_and_custom_tools():
    return {
        "builtin_tools": BuiltinToolGenerator().data_examples()[0],
        "custom_tools": JsonCustomToolGenerator().data_examples()[0],
        "instruction": INSTRUCTION,
    }


def system_default():
    return {
        "builtin_tools": [],
        "custom_tools": [],
        "instruction": INSTRUCTION,
    }


def tool_success():
    return ToolResponseGenerator().data_examples()[0]


def tool_failure():
    return ToolResponseGenerator().data_examples()[1]


def assistant_builtin_tool_call():
    return {
        "content": "",
        "tool_call": ToolCall(
            call_id="uuid",
            tool_name=BuiltinTool.brave_search,
            arguments={
                "query": "Who won NBA in 2024?",
            },
        ),
        "stop_reason": StopReason.end_of_message,
    }


def assistant_custom_tool_call():
    return {
        "content": "",
        "tool_call": ToolCall(
            call_id="uuid",
            tool_name="trending_songs",
            arguments={"country": "US", "n": 10},
        ),
        "stop_reason": StopReason.end_of_turn,
    }


def assistant_default():
    return {
        "content": "Hi, I am a helpful assistant. What can I help you with today?",
        "tool_call": None,
        "stop_reason": StopReason.end_of_turn,
    }


def user_default():
    return {"content": "Please tell me how to plan a trip to New York"}


def user_images():
    return {"content": "<|image|><|image|>What do these images depict?"}


def user_interleaved_images():
    return {"content": "<|image|>Describe the image in one sentence.<|image|>Write a haiku about these images"}
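These providers are what `render_jinja_template` looks up via `getattr(template_data, template.data_provider)`; each returns the keyword arguments for the matching `LLama31Interface` method. A minimal sketch of that flow (assumed usage):

```python
# Minimal sketch (assumed usage): a data provider's dict becomes the kwargs
# for the corresponding LLama31Interface method.
interface = LLama31Interface(ToolPromptFormat.json)
messages = interface.system_messages(**system_message_builtin_and_custom_tools())
tokens = interface.get_tokens(messages)
```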
199 llama_stack/models/llama/llama3/test_system_prompts.py Normal file
@@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
import textwrap
|
||||||
|
import unittest
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from .prompt_templates import (
|
||||||
|
BuiltinToolGenerator,
|
||||||
|
FunctionTagCustomToolGenerator,
|
||||||
|
JsonCustomToolGenerator,
|
||||||
|
PythonListCustomToolGenerator,
|
||||||
|
SystemDefaultGenerator,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PromptTemplateTests(unittest.TestCase):
|
||||||
|
def check_generator_output(self, generator, expected_text):
|
||||||
|
example = generator.data_examples()[0]
|
||||||
|
|
||||||
|
pt = generator.gen(example)
|
||||||
|
text = pt.render()
|
||||||
|
# print(text) # debugging
|
||||||
|
assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
|
||||||
|
|
||||||
|
def test_system_default(self):
|
||||||
|
generator = SystemDefaultGenerator()
|
||||||
|
today = datetime.now().strftime("%d %B %Y")
|
||||||
|
expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
|
||||||
|
self.check_generator_output(generator, expected_text)
|
||||||
|
|
||||||
|
def test_system_builtin_only(self):
|
||||||
|
generator = BuiltinToolGenerator()
|
||||||
|
expected_text = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Environment: ipython
|
||||||
|
Tools: brave_search, wolfram_alpha
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||||
|
|
||||||
|
def test_system_custom_only(self):
|
||||||
|
self.maxDiff = None
|
||||||
|
generator = JsonCustomToolGenerator()
|
||||||
|
expected_text = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Answer the user's question by making use of the following functions if needed.
|
||||||
|
If none of the function can be used, please say so.
|
||||||
|
Here is a list of functions in JSON format:
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "trending_songs",
|
||||||
|
"description": "Returns the trending songs on a Music site",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": [
|
||||||
|
{
|
||||||
|
"n": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "The number of songs to return"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"genre": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "The genre of the songs to return"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"required": ["n"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Return function calls in JSON format.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||||
|
|
||||||
|
def test_system_custom_function_tag(self):
|
||||||
|
self.maxDiff = None
|
||||||
|
generator = FunctionTagCustomToolGenerator()
|
||||||
|
expected_text = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
You have access to the following functions:
|
||||||
|
|
||||||
|
Use the function 'trending_songs' to 'Returns the trending songs on a Music site':
|
||||||
|
{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}
|
||||||
|
|
||||||
|
Think very carefully before calling functions.
|
||||||
|
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
||||||
|
|
||||||
|
<function=example_function_name>{"example_name": "example_value"}</function>
|
||||||
|
|
||||||
|
Reminder:
|
||||||
|
- If looking for real time information use relevant functions before falling back to brave_search
|
||||||
|
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
||||||
|
- Required parameters MUST be specified
|
||||||
|
- Only call one function at a time
|
||||||
|
- Put the entire function call reply on one line
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||||
|
|
||||||
|
def test_llama_3_2_system_zero_shot(self):
|
||||||
|
generator = PythonListCustomToolGenerator()
|
||||||
|
expected_text = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||||
|
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||||
|
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
||||||
|
also point it out. You should only return the function call in tools call sections.
|
||||||
|
|
||||||
|
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||||
|
You SHOULD NOT include any other text in the response.
|
||||||
|
|
||||||
|
Here is a list of functions in JSON format that you can invoke.
|
||||||
|
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get weather info for places",
|
||||||
|
"parameters": {
|
||||||
|
"type": "dict",
|
||||||
|
"required": ["city"],
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The name of the city to get the weather for"
|
||||||
|
},
|
||||||
|
"metric": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||||
|
"default": "celsius"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||||
|
|
||||||
|
def test_llama_3_2_provided_system_prompt(self):
|
||||||
|
generator = PythonListCustomToolGenerator()
|
||||||
|
expected_text = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Overriding message.
|
||||||
|
|
||||||
|
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||||
|
You SHOULD NOT include any other text in the response.
|
||||||
|
|
||||||
|
Here is a list of functions in JSON format that you can invoke.
|
||||||
|
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get weather info for places",
|
||||||
|
"parameters": {
|
||||||
|
"type": "dict",
|
||||||
|
"required": ["city"],
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The name of the city to get the weather for"
|
||||||
|
},
|
||||||
|
"metric": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||||
|
"default": "celsius"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]"""
|
||||||
|
)
|
||||||
|
user_system_prompt = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Overriding message.
|
||||||
|
|
||||||
|
{{ function_description }}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
example = generator.data_examples()[0]
|
||||||
|
|
||||||
|
pt = generator.gen(example, user_system_prompt)
|
||||||
|
text = pt.render()
|
||||||
|
assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
|
12 llama_stack/models/llama/llama3_1/__init__.py Normal file
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
259 llama_stack/models/llama/llama3_1/prompts.py Normal file
@@ -0,0 +1,259 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
import textwrap
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from llama_models.datatypes import (
|
||||||
|
BuiltinTool,
|
||||||
|
RawMessage,
|
||||||
|
StopReason,
|
||||||
|
ToolCall,
|
||||||
|
ToolPromptFormat,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..prompt_format import (
|
||||||
|
# llama3_1_e2e_tool_call_dialog,
|
||||||
|
TextCompletionContent,
|
||||||
|
UseCase,
|
||||||
|
llama3_1_builtin_tool_call_dialog,
|
||||||
|
llama3_1_custom_tool_call_dialog,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def wolfram_alpha_response():
|
||||||
|
return textwrap.dedent(
|
||||||
|
"""
|
||||||
|
{
|
||||||
|
"queryresult": {
|
||||||
|
"success": true,
|
||||||
|
"inputstring": "100th decimal of pi",
|
||||||
|
"pods": [
|
||||||
|
{
|
||||||
|
"title": "Input interpretation",
|
||||||
|
"subpods": [
|
||||||
|
{
|
||||||
|
"title": "",
|
||||||
|
"plaintext": "100th digit | \u03c0"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"title": "Nearby digits",
|
||||||
|
"subpods": [
|
||||||
|
{
|
||||||
|
"title": "",
|
||||||
|
"plaintext": "...86208998628034825342117067982148086513282306647093..."
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"title": "Result",
|
||||||
|
"primary": true,
|
||||||
|
"subpods": [
|
||||||
|
{
|
||||||
|
"title": "",
|
||||||
|
"plaintext": "7"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def usecases() -> List[UseCase | str]:
|
||||||
|
return [
|
||||||
|
textwrap.dedent(
|
||||||
|
"""
|
||||||
|
# Llama 3.1 - Prompt Formats
|
||||||
|
## Tokens
|
||||||
|
Here is a list of special tokens that are supported by Llama 3.1:
|
||||||
|
- `<|begin_of_text|>`: Specifies the start of the prompt
|
||||||
|
- `<|end_of_text|>`: Model will cease to generate more tokens. This token is generated only by the base models.
|
||||||
|
- `<|finetune_right_pad_id|>`: This token is used for padding text sequences to the same length in a batch.
|
||||||
|
- `<|start_header_id|>` and `<|end_header_id|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user, assistant and tool]
|
||||||
|
- `<|eom_id|>`: End of message. A message represents a possible stopping point for execution where the model can inform the executor that a tool call needs to be made. This is used for multi-step interactions between the model and any available tools. This token is emitted by the model when the Environment: ipython instruction is used in the system prompt, or if the model calls for a built-in tool.
|
||||||
|
- `<|eot_id|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios:
|
||||||
|
- at the end of a direct interaction between the model and the user
|
||||||
|
- at the end of multiple interactions between the model and any available tools
|
||||||
|
This token signals to the executor that the model has finished generating a response.
|
||||||
|
- `<|python_tag|>`: Is a special tag used in the model's response to signify a tool call.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
textwrap.dedent(
|
||||||
|
"""
|
||||||
|
There are 4 different roles that are supported by Llama 3.1
|
||||||
|
- `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively.
|
||||||
|
- `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model.
|
||||||
|
- `tool`: A new role introduced in Llama 3.1. This role is used to mark messages with the output of a tool call when sent back to the model from the executor. (The actual token used by the model for this role is "ipython".)
|
||||||
|
- `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `tool` and `user` prompts.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
UseCase(
|
||||||
|
title="Llama 3.1 Base Model",
|
||||||
|
description="Text completion for Llama 3.1 base model uses this format.",
|
||||||
|
dialogs=[TextCompletionContent(content="Color of sky is blue but sometimes can also be")],
|
||||||
|
notes="Note start special tag",
|
||||||
|
),
|
||||||
|
"## Llama 3.1 Instruct Model",
|
||||||
|
UseCase(
|
||||||
|
title="User and assistant conversation",
|
||||||
|
description="Here is a regular multi-turn user assistant conversation and how its formatted.",
|
||||||
|
dialogs=[
|
||||||
|
[
|
||||||
|
RawMessage(role="system", content="You are a helpful assistant"),
|
||||||
|
RawMessage(
|
||||||
|
role="user",
|
||||||
|
content="Answer who are you in the form of jeopardy?",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
],
|
||||||
|
notes="",
|
||||||
|
),
|
||||||
|
"## Tool Calling Formats",
|
||||||
|
textwrap.dedent(
|
||||||
|
"""
|
||||||
|
The three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt:
|
||||||
|
- Brave Search: Tool call to perform web searches.
|
||||||
|
- Wolfram Alpha: Tool call to perform complex mathematical calculations.
|
||||||
|
- Code Interpreter: Enables the model to output python code.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
UseCase(
|
||||||
|
title="Builtin Tool Calling",
|
||||||
|
description=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Here is an example of a conversation using brave search
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
dialogs=[llama3_1_builtin_tool_call_dialog()],
|
||||||
|
notes=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
- Just including Environment: ipython turns on code interpreter; therefore, you don't need to specify code interpretation on the Tools: line. The model can generate python code which is interpreted by the executor, with the result provided back to the model.
|
||||||
|
- The message body of the assistant response starts with a special tag <|python_tag|>
|
||||||
|
- As alluded to above, in such an environment, the model can generate <|eom_id|> instead of just the standard <|eot_id|> . The latter indicates the turn is finished, while the former indicates continued multi-step reasoning. That is, the model is expecting a continuation message with the output of the tool call.
|
||||||
|
- The model tool call response is of the form `tool.call(query="...")` where tool is `brave_search` or `wolfram_alpha`
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
),
|
||||||
|
UseCase(
|
||||||
|
title="Builtin Code Interpreter",
|
||||||
|
description="Here is an actual example of model responding with code",
|
||||||
|
dialogs=[
|
||||||
|
[
|
||||||
|
RawMessage(role="system", content="Environment: ipython"),
|
||||||
|
RawMessage(
|
||||||
|
role="user",
|
||||||
|
content="Write code to check if number is prime, use that to see if the number 7 is prime",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
],
|
||||||
|
notes=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
- Model starts with <|python_tag|> and continues writing python code that it needs to be executed
|
||||||
|
- No explicit mention of code_interpreter in system prompt. `Environment: ipython` implicitly enables it.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
),
|
||||||
|
UseCase(
|
||||||
|
title="Built-in tools full interaction",
|
||||||
|
description="Here is a full interaction with the built-in tools including the tool response and the final assistant response.",
|
||||||
|
dialogs=[
|
||||||
|
[
|
||||||
|
RawMessage(
|
||||||
|
role="system",
|
||||||
|
content="Environment: ipython\nTools: brave_search, wolfram_alpha\n",
|
||||||
|
),
|
||||||
|
RawMessage(role="user", content="What is the 100th decimal of pi?"),
|
||||||
|
RawMessage(
|
||||||
|
role="assistant",
|
||||||
|
content="",
|
||||||
|
stop_reason=StopReason.end_of_message,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(
|
||||||
|
call_id="tool_call_id",
|
||||||
|
tool_name=BuiltinTool.wolfram_alpha,
|
||||||
|
arguments={"query": "100th decimal of pi"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
RawMessage(
|
||||||
|
role="tool",
|
||||||
|
content=wolfram_alpha_response(),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
],
|
||||||
|
notes=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
- Note the `<|python_tag|>` in the assistant response.
|
||||||
|
- Role is `tool` for the wolfram alpha response that is passed back to the model.
|
||||||
|
- Final message from assistant has <|eot_id|> tag.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
),
|
||||||
|
"## Zero shot tool calling",
|
||||||
|
UseCase(
|
||||||
|
title="JSON based tool calling",
|
||||||
|
description=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Llama models can now output custom tool calls from a single message to allow easier tool calling.
|
||||||
|
The following prompts provide an example of how custom tools can be called from the output of the model.
|
||||||
|
It's important to note that the model itself does not execute the calls; it provides structured output to facilitate calling by an executor.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
dialogs=[llama3_1_custom_tool_call_dialog()],
|
||||||
|
notes=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
- JSON format for providing tools needs name, description and parameters
|
||||||
|
- Model responds with `<|python_tag|>` and `<|eom_id|>` as `Environment: ipython` was in the system prompt
|
||||||
|
- Instructions for tools added as a user message
|
||||||
|
- Only single tool calls are supported as of now
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
),
|
||||||
|
# FIXME: This is not working yet as expected
|
||||||
|
# UseCase(
|
||||||
|
# title="E2E tool call example",
|
||||||
|
# description=textwrap.dedent(
|
||||||
|
# """
|
||||||
|
# Here is an example showing the whole multi-step turn by taking custom tool outputs and passing back to the model.
|
||||||
|
# """
|
||||||
|
# ),
|
||||||
|
# dialogs=[
|
||||||
|
# llama3_1_e2e_tool_call_dialog(
|
||||||
|
# tool_prompt_format=ToolPromptFormat.function_tag
|
||||||
|
# )
|
||||||
|
# ],
|
||||||
|
# notes="",
|
||||||
|
# ),
|
||||||
|
"## Example of a user defined tool calling",
|
||||||
|
UseCase(
|
||||||
|
title="`<function>` based tool calling",
|
||||||
|
description=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Here is an example of how you could also write custom instructions for model to do zero shot tool calling.
|
||||||
|
In this example, we define a custom tool calling format using the `<function>` tag.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
dialogs=[llama3_1_custom_tool_call_dialog(ToolPromptFormat.function_tag)],
|
||||||
|
notes=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
- In this case, model does NOT respond with `<|python_tag|>` and ends with `<|eot_id|>`
|
||||||
|
- Instructions for tools added as a user message
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
12 llama_stack/models/llama/llama3_2/__init__.py Normal file
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
235 llama_stack/models/llama/llama3_2/prompts_text.py Normal file
@@ -0,0 +1,235 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
import json
|
||||||
|
import textwrap
|
||||||
|
|
||||||
|
from llama_models.datatypes import (
|
||||||
|
RawMessage,
|
||||||
|
StopReason,
|
||||||
|
ToolCall,
|
||||||
|
ToolPromptFormat,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..prompt_format import (
|
||||||
|
TextCompletionContent,
|
||||||
|
UseCase,
|
||||||
|
llama3_1_builtin_code_interpreter_dialog,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def user_tool_call():
|
||||||
|
content = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Questions: Can you retrieve the details for the user with the ID 7890, who has black as their special request?
|
||||||
|
Here is a list of functions in JSON format that you can invoke:
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "get_user_info",
|
||||||
|
"description": "Retrieve details for a specific user by their unique identifier. Note that the provided function is in Python 3 syntax.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "dict",
|
||||||
|
"required": [
|
||||||
|
"user_id"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"user_id": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "The unique identifier of the user. It is used to fetch the specific user details from the database."
|
||||||
|
},
|
||||||
|
"special": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Any special information or parameters that need to be considered while fetching user details.",
|
||||||
|
"default": "none"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
Should you decide to return the function call(s),Put it in the format of [func1(params_name=params_value, params_name2=params_value2...), func2(params)]
|
||||||
|
|
||||||
|
NO other text MUST be included.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return content.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def system_tool_call():
|
||||||
|
content = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||||
|
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||||
|
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
||||||
|
also point it out. You should only return the function call in tools call sections.
|
||||||
|
|
||||||
|
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||||
|
You SHOULD NOT include any other text in the response.
|
||||||
|
|
||||||
|
Here is a list of functions in JSON format that you can invoke.
|
||||||
|
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get weather info for places",
|
||||||
|
"parameters": {
|
||||||
|
"type": "dict",
|
||||||
|
"required": [
|
||||||
|
"city"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The name of the city to get the weather for"
|
||||||
|
},
|
||||||
|
"metric": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||||
|
"default": "celsius"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return content.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def usecases():
|
||||||
|
return [
|
||||||
|
UseCase(
|
||||||
|
title="User and assistant conversation",
|
||||||
|
description="Here is a regular multi-turn user assistant conversation and how its formatted.",
|
||||||
|
dialogs=[
|
||||||
|
[
|
||||||
|
RawMessage(role="system", content="You are a helpful assistant"),
|
||||||
|
RawMessage(role="user", content="Who are you?"),
|
||||||
|
]
|
||||||
|
],
|
||||||
|
notes="This format is unchanged from Llama3.1",
|
||||||
|
),
|
||||||
|
UseCase(
|
||||||
|
title="Zero shot function calling",
|
||||||
|
description=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
For Llama3.2 1B and 3B instruct models, we are introducing a new format for zero shot function calling.
|
||||||
|
This new format is designed to be more flexible and powerful than the previous format.
|
||||||
|
All available functions can be provided in the system message. A key difference is in the format of how the assistant responds with function calls.
|
||||||
|
It is pythonic in the form of `[func1(params_name=params_value, params_name2=params_value2...), func2(params)]` instead of the `json` or `<function>` tag that were defined in Llama3.1.
|
||||||
|
Here is an example for the same,
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
dialogs=[
|
||||||
|
# Zero shot tool calls as system message
|
||||||
|
[
|
||||||
|
RawMessage(role="system", content=system_tool_call()),
|
||||||
|
RawMessage(role="user", content="What is the weather in SF and Seattle?"),
|
||||||
|
],
|
||||||
|
],
|
||||||
|
notes=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
- The output supports multiple tool calls natively
|
||||||
|
- JSON format for defining the functions in the system prompt is similar to Llama3.1
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
),
|
||||||
|
UseCase(
|
||||||
|
title="Zero shot function calling with user message",
|
||||||
|
description=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
While the default is to provide all function calls in a system message, in Llama3.2 text models you can also provide information for all the available tools in a user message.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
dialogs=[
|
||||||
|
# Zero shot tool call as user message
|
||||||
|
[
|
||||||
|
RawMessage(role="user", content=user_tool_call()),
|
||||||
|
],
|
||||||
|
],
|
||||||
|
notes=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
- The tool call format for the model is the same whether your function calls are provided in the system or user message.
|
||||||
|
- While builtin tool calls end with a <|eom_id|>, notice the <|eot_id|> for zero shot tool calls.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
),
|
||||||
|
UseCase(
|
||||||
|
title="Code Interpreter",
|
||||||
|
description=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Code Interpreter continues to work in 3.2 text models similar to Llama 3.1 model family.
|
||||||
|
Here is an example,
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
dialogs=[llama3_1_builtin_code_interpreter_dialog()],
|
||||||
|
notes=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
- Note `Environment: ipython` in the system prompt.
|
||||||
|
- Note that the response starts with `<|python_tag|>` and ends with `<|eom_id|>`
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
),
|
||||||
|
UseCase(
|
||||||
|
title="Zero shot function calling E2E format",
|
||||||
|
description=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Here is an example of the e2e cycle of tool calls with the model in a multi-step way.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
dialogs=[
|
||||||
|
[
|
||||||
|
RawMessage(role="system", content=system_tool_call()),
|
||||||
|
RawMessage(role="user", content="What is the weather in SF?"),
|
||||||
|
RawMessage(
|
||||||
|
role="assistant",
|
||||||
|
content="",
|
||||||
|
stop_reason=StopReason.end_of_turn,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(
|
||||||
|
call_id="cc",
|
||||||
|
tool_name="get_weather",
|
||||||
|
arguments={
|
||||||
|
"city": "San Francisco",
|
||||||
|
"metric": "celsius",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
RawMessage(
|
||||||
|
role="tool",
|
||||||
|
content=json.dumps("25 C"),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
],
|
||||||
|
notes=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
- The output of the function call is provided back to the model as a tool response ( in json format ).
|
||||||
|
- Notice `<|start_header_id|>ipython<|end_header_id|>` as the header message preceding the tool response.
|
||||||
|
- The model finally summarizes the information from the tool response and returns the result to the user.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
tool_prompt_format=ToolPromptFormat.python_list,
|
||||||
|
),
|
||||||
|
UseCase(
|
||||||
|
title="Prompt format for base models",
|
||||||
|
description=textwrap.dedent(
|
||||||
|
"""
|
||||||
|
For base models (Llama3.2-1B and Llama3.2-3B), the prompt format for a simple completion is as follows
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
dialogs=[
|
||||||
|
TextCompletionContent(content="The color of the sky is blue but sometimes it can also be"),
|
||||||
|
],
|
||||||
|
notes="Same as Llama3.1",
|
||||||
|
),
|
||||||
|
]
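The zero-shot function-calling format described in the use cases above is pythonic rather than JSON or `<function>`-tagged. For the `get_weather` example, the raw assistant completion would look roughly like this (a sketch based on the stated format, not a captured model output):

```
[get_weather(city="San Francisco", metric="celsius"), get_weather(city="Seattle", metric="celsius")]
```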
133 llama_stack/models/llama/llama3_2/prompts_vision.py Normal file
@@ -0,0 +1,133 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
import textwrap
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from llama_models.datatypes import (
|
||||||
|
RawMediaItem,
|
||||||
|
RawMessage,
|
||||||
|
RawTextItem,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..prompt_format import (
|
||||||
|
TextCompletionContent,
|
||||||
|
UseCase,
|
||||||
|
llama3_1_builtin_tool_call_dialog,
|
||||||
|
# llama3_1_builtin_tool_call_with_image_dialog,
|
||||||
|
llama3_2_user_assistant_conversation,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def usecases():
|
||||||
|
this_dir = Path(__file__).parent.parent.resolve()
|
||||||
|
with open(this_dir / "scripts/resources/dog.jpg", "rb") as f:
|
||||||
|
img = f.read()
|
||||||
|
|
||||||
|
return [
|
||||||
|
llama3_2_user_assistant_conversation(),
|
||||||
|
UseCase(
|
||||||
|
title="User and assistant conversation with Images",
|
||||||
|
description="This example shows how to pass and image to the model as part of the messages.",
|
            dialogs=[
                [
                    RawMessage(
                        role="user",
                        content=[
                            RawMediaItem(data=img),
                            RawTextItem(text="Describe this image in two sentences"),
                        ],
                    )
                ],
            ],
            notes=textwrap.dedent(
                """
                - The `<|image|>` tag is used to indicate presence of the image
                - The model isn't an early fusion model so doesn't actually translate an image into several tokens. Instead the cross-attention layers take input "on the side" from a vision encoder
                
                - It's important to position the <|image|> tag appropriately in the prompt. The image will only attend to the subsequent text tokens
                - The <|image|> tag is part of the user message body, implying that it should only come after the header `<|start_header_id|>{role}<|end_header_id|>` in the message body
                - We recommend using a single image in one prompt
                """
            ),
        ),
        UseCase(
            title="Builtin and Zero Shot Tool Calling",
            description=textwrap.dedent(
                """
                Llama3.2 vision models follow the same tool calling format as Llama3.1 models when inputs are text only.
                Use `Environment: ipython` to enable tools.
                Add `Tools: {{tool_name1}},{{tool_name2}}` for each of the builtin tools.
                The same builtin tools as Llama3.1 are available,
                - code_interpreter (for executing python code)
                - brave_search (to search the web)
                - wolfram_alpha (for querying wolfram alpha for mathematical questions)
                """,
            ),
            dialogs=[llama3_1_builtin_tool_call_dialog()],
            notes=textwrap.dedent(
                """
                - Note the `<|python_tag|>` before `brave_search` function call.
                - The `<|eom_id|>` tag is used to indicate the end of the message.
                - Similar to Llama3.1, code_interpreter is not explicitly mentioned but is enabled via `Environment: ipython`.
                - Tool Calling does NOT work with images in the prompt as of now.
                """
            ),
        ),
        # UseCase(
        #     title="Tool Calling for vision models",
        #     description=textwrap.dedent(
        #         """
        #         While Llama3.2 vision models follow the same tool calling format as Llama3.1 models when inputs are text only,
        #         they are not able to do tool calling when prompt contains image inputs (along with text).
        #         The recommended way would be to separate out the image understanding from the tool calling in successive prompts.
        #         Here is an example of how that could be done,
        #         """,
        #     ),
        #     dialogs=[llama3_1_builtin_tool_call_with_image_dialog()],
        #     notes=textwrap.dedent(
        #         """
        #         - Instead of a single prompt (image understanding + tool call), we split into two prompts to achieve the same result.
        #         """
        #     ),
        # ),
        UseCase(
            title="Prompt format for base models",
            description=textwrap.dedent(
                """
                For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), the prompt format for a simple completion is as follows
                """
            ),
            dialogs=[
                TextCompletionContent(content="The color of the sky is blue but sometimes it can also be"),
            ],
            notes="- Same as Llama3.1",
        ),
        UseCase(
            title="Prompt format for base models with Image",
            description=textwrap.dedent(
                """
                For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), here is an example of how the text completion format looks with an image,
                """
            ),
            dialogs=[
                TextCompletionContent(
                    content=[
                        RawMediaItem(data=img),
                        RawTextItem(text="If I had to write a haiku for this one"),
                    ]
                ),
            ],
            notes="- Note the placement of the special tags <|begin_of_text|> and <|image|>",
        ),
    ]
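Editor's note: to make the placement rules in the notes above concrete, here is a hedged sketch of roughly how the rendered user turn lays out for an image-plus-text message. The special tokens come from the notes; the exact whitespace is up to the chat formatter, so treat this as illustrative rather than canonical output.

```python
# Hedged sketch: approximate token layout for an image + text user message.
# The <|image|> tag sits inside the user message body, after the role header and
# immediately before the text tokens it should condition on.
example_user_turn = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "<|image|>Describe this image in two sentences<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
```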
258  llama_stack/models/llama/llama3_3/prompts.py  Normal file
@@ -0,0 +1,258 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

import textwrap
from typing import List

from llama_models.datatypes import (
    BuiltinTool,
    RawMessage,
    StopReason,
    ToolCall,
    ToolPromptFormat,
)

from ..prompt_format import (
    # llama3_1_e2e_tool_call_dialog,
    TextCompletionContent,
    UseCase,
    llama3_1_builtin_tool_call_dialog,
    llama3_1_custom_tool_call_dialog,
)


def wolfram_alpha_response():
    return textwrap.dedent(
        """
        {
            "queryresult": {
                "success": true,
                "inputstring": "100th decimal of pi",
                "pods": [
                    {
                        "title": "Input interpretation",
                        "subpods": [
                            {
                                "title": "",
                                "plaintext": "100th digit | \u03c0"
                            }
                        ]
                    },
                    {
                        "title": "Nearby digits",
                        "subpods": [
                            {
                                "title": "",
                                "plaintext": "...86208998628034825342117067982148086513282306647093..."
                            }
                        ]
                    },
                    {
                        "title": "Result",
                        "primary": true,
                        "subpods": [
                            {
                                "title": "",
                                "plaintext": "7"
                            }
                        ]
                    }
                ]
            }
        }
        """
    )


def usecases() -> List[UseCase | str]:
    return [
        textwrap.dedent(
            """
            # Llama 3.1 - Prompt Formats
            ## Tokens
            Here is a list of special tokens that are supported by Llama 3.1:
            - `<|begin_of_text|>`: Specifies the start of the prompt
            - `<|end_of_text|>`: Model will cease to generate more tokens. This token is generated only by the base models.
            - `<|finetune_right_pad_id|>`: This token is used for padding text sequences to the same length in a batch.
            - `<|start_header_id|>` and `<|end_header_id|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user, assistant and tool]
            - `<|eom_id|>`: End of message. A message represents a possible stopping point for execution where the model can inform the executor that a tool call needs to be made. This is used for multi-step interactions between the model and any available tools. This token is emitted by the model when the Environment: ipython instruction is used in the system prompt, or if the model calls for a built-in tool.
            - `<|eot_id|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios:
                - at the end of a direct interaction between the model and the user
                - at the end of multiple interactions between the model and any available tools
                This token signals to the executor that the model has finished generating a response.
            - `<|python_tag|>`: Is a special tag used in the model's response to signify a tool call.
            """
        ),
        textwrap.dedent(
            """
            There are 4 different roles that are supported by Llama 3.1
            - `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively.
            - `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model.
            - `tool`: A new role introduced in Llama 3.1. This role is used to mark messages with the output of a tool call when sent back to the model from the executor. (The actual token used by the model for this role is "ipython".)
            - `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `tool` and `user` prompts.
            """
        ),
        UseCase(
            title="Llama 3.1 Base Model",
            description="Text completion for Llama 3.1 base model uses this format.",
            dialogs=[TextCompletionContent(content="Color of sky is blue but sometimes can also be")],
            notes="Note start special tag",
        ),
        "## Llama 3.1 Instruct Model",
        UseCase(
            title="User and assistant conversation",
            description="Here is a regular multi-turn user assistant conversation and how it's formatted.",
            dialogs=[
                [
                    RawMessage(role="system", content="You are a helpful assistant"),
                    RawMessage(
                        role="user",
                        content="Answer who are you in the form of jeopardy?",
                    ),
                ]
            ],
            notes="",
        ),
        "## Tool Calling Formats",
        textwrap.dedent(
            """
            The three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt:
            - Brave Search: Tool call to perform web searches.
            - Wolfram Alpha: Tool call to perform complex mathematical calculations.
            - Code Interpreter: Enables the model to output python code.
            """
        ),
        UseCase(
            title="Builtin Tool Calling",
            description=textwrap.dedent(
                """
                Here is an example of a conversation using brave search
                """
            ),
            dialogs=[llama3_1_builtin_tool_call_dialog()],
            notes=textwrap.dedent(
                """
                - Just including Environment: ipython turns on code interpreter; therefore, you don't need to specify code interpretation on the Tools: line. The model can generate python code which is interpreted by the executor, with the result provided back to the model.
                - The message body of the assistant response starts with a special tag <|python_tag|>
                - As alluded to above, in such an environment, the model can generate <|eom_id|> instead of just the standard <|eot_id|>. The latter indicates the turn is finished, while the former indicates continued multi-step reasoning. That is, the model is expecting a continuation message with the output of the tool call.
                - The model tool call response is of the form `tool.call(query="...")` where tool is `brave_search` or `wolfram_alpha`
                """
            ),
        ),
        UseCase(
            title="Builtin Code Interpreter",
            description="Here is an actual example of model responding with code",
            dialogs=[
                [
                    RawMessage(role="system", content="Environment: ipython"),
                    RawMessage(
                        role="user",
                        content="Write code to check if number is prime, use that to see if the number 7 is prime",
                    ),
                ],
            ],
            notes=textwrap.dedent(
                """
                - Model starts with <|python_tag|> and continues writing python code that needs to be executed
                - No explicit mention of code_interpreter in system prompt. `Environment: ipython` implicitly enables it.
                """
            ),
        ),
        UseCase(
            title="Built-in tools full interaction",
            description="Here is a full interaction with the built-in tools including the tool response and the final assistant response.",
            dialogs=[
                [
                    RawMessage(
                        role="system",
                        content="Environment: ipython\nTools: brave_search, wolfram_alpha\n",
                    ),
                    RawMessage(role="user", content="What is the 100th decimal of pi?"),
                    RawMessage(
                        content="",
                        stop_reason=StopReason.end_of_message,
                        tool_calls=[
                            ToolCall(
                                call_id="tool_call_id",
                                tool_name=BuiltinTool.wolfram_alpha,
                                arguments={"query": "100th decimal of pi"},
                            )
                        ],
                    ),
                    RawMessage(
                        role="tool",
                        content=wolfram_alpha_response(),
                    ),
                ],
            ],
            notes=textwrap.dedent(
                """
                - Note the `<|python_tag|>` in the assistant response.
                - Role is `tool` for the wolfram alpha response that is passed back to the model.
                - Final message from assistant has <|eot_id|> tag.
                """
            ),
        ),
        "## Zero shot tool calling",
        UseCase(
            title="JSON based tool calling",
            description=textwrap.dedent(
                """
                Llama models can now output custom tool calls from a single message to allow easier tool calling.
                The following prompts provide an example of how custom tools can be called from the output of the model.
                It's important to note that the model itself does not execute the calls; it provides structured output to facilitate calling by an executor.
                """
            ),
            dialogs=[llama3_1_custom_tool_call_dialog()],
            notes=textwrap.dedent(
                """
                - JSON format for providing tools needs name, description and parameters
                - Model responds with `<|python_tag|>` and `<|eom_id|>` as `Environment: ipython` was in the system prompt
                - Instructions for tools added as a user message
                - Only single tool calls are supported as of now
                """
            ),
        ),
        # FIXME: This is not working yet as expected
        # UseCase(
        #     title="E2E tool call example",
        #     description=textwrap.dedent(
        #         """
        #         Here is an example showing the whole multi-step turn by taking custom tool outputs and passing back to the model.
        #         """
        #     ),
        #     dialogs=[
        #         llama3_1_e2e_tool_call_dialog(
        #             tool_prompt_format=ToolPromptFormat.function_tag
        #         )
        #     ],
        #     notes="",
        # ),
        "## Example of a user defined tool calling",
        UseCase(
            title="`<function>` based tool calling",
            description=textwrap.dedent(
                """
                Here is an example of how you could also write custom instructions for model to do zero shot tool calling.
                In this example, we define a custom tool calling format using the `<function>` tag.
                """
            ),
            dialogs=[llama3_1_custom_tool_call_dialog(ToolPromptFormat.function_tag)],
            notes=textwrap.dedent(
                """
                - In this case, model does NOT respond with `<|python_tag|>` and ends with `<|eot_id|>`
                - Instructions for tools added as a user message
                """
            ),
        ),
    ]
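Editor's note: the special-token list in the docstrings above fully determines the skeleton of an instruct prompt. A hedged illustration of a single system + user exchange assembled from those tokens (the message text is placeholder, not output produced by this file):

```python
# Hedged sketch of the instruct-prompt skeleton implied by the token list above.
# Only the special tokens are taken from the documentation; the text is arbitrary.
prompt_skeleton = (
    "<|begin_of_text|>"
    "<|start_header_id|>system<|end_header_id|>\n\n"
    "You are a helpful assistant<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n"
    "Answer who are you in the form of jeopardy?<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
```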
204  llama_stack/models/llama/prompt_format.py  Normal file
@@ -0,0 +1,204 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

import json
import textwrap
from pathlib import Path
from typing import List

from llama_models.datatypes import (
    RawContent,
    RawMediaItem,
    RawMessage,
    RawTextItem,
    StopReason,
    ToolCall,
    ToolPromptFormat,
)
from pydantic import BaseModel, Field

from .llama3.interface import LLama31Interface
from .llama3.template_data import (
    system_message_builtin_code_only,
    system_message_builtin_tools_only,
    system_message_custom_tools_only,
)


class TextCompletionContent(BaseModel):
    content: RawContent = ""


class UseCase(BaseModel):
    title: str = ""
    description: str = ""
    dialogs: List[List[RawMessage] | TextCompletionContent | str] = Field(default_factory=list)
    notes: str = ""
    tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json

    def md_format(self):
        section = textwrap.dedent(
            """
            ## {title}

            {description}

            {dialogs_text}
            {notes}

            """
        )
        return section.lstrip()

    def dialogs_to_text(self, generator) -> str:
        def _code_block(text):
            return f"```\n{text}\n```"

        text = ""
        for dialog in self.dialogs:
            if isinstance(dialog, str):
                text += dialog
                text += "\n\n"
                continue

            elif isinstance(dialog, TextCompletionContent):
                input_tokens, output_tokens = generator.text_completion_raw(
                    dialog.content,
                    max_gen_len=64,
                    temperature=0.1,
                    top_p=0.95,
                )
            else:
                input_tokens, output_tokens = generator.chat_completion_raw(
                    dialog,
                    max_gen_len=512,
                    temperature=0.0,
                    top_p=0.95,
                    tool_prompt_format=self.tool_prompt_format,
                )
            text += "##### Input Prompt Format\n"

            # FIXME: This is added to undo the hack in chat_formatter where
            # vision tokens are replaced with 128256.
            input_tokens = [generator.formatter.vision_token if t == 128256 else t for t in input_tokens]

            text += _code_block(generator.tokenizer.decode(input_tokens))
            # TODO: Figure out if "↵" needs to be added for newlines or end or some indication
            text += "\n\n"
            text += "##### Model Response Format\n"
            text += _code_block(generator.tokenizer.decode(output_tokens))
            text += "\n\n"

        return text

    def to_text(self, generator):
        section = self.md_format()
        dialogs_text = self.dialogs_to_text(generator)
        notes = f"##### Notes\n{self.notes}" if self.notes else ""
        section = section.format(
            title=self.title,
            description=self.description,
            dialogs_text=dialogs_text,
            notes=notes,
        )
        return section


def llama3_1_builtin_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
    interface = LLama31Interface(tool_prompt_format)

    messages = interface.system_messages(**system_message_builtin_tools_only())
    messages += interface.user_message(content="Search the web for the latest price of 1oz gold?")

    return messages


def llama3_1_builtin_code_interpreter_dialog(tool_prompt_format=ToolPromptFormat.json):
    interface = LLama31Interface(tool_prompt_format)

    messages = interface.system_messages(**system_message_builtin_code_only())
    messages += interface.user_message(
        content="Write code to check if number is prime. Use it to verify if number 7 is prime"
    )

    return messages


def llama3_1_builtin_tool_call_with_image_dialog(
    tool_prompt_format=ToolPromptFormat.json,
):
    this_dir = Path(__file__).parent
    with open(this_dir / "llama3/dog.jpg", "rb") as f:
        img = f.read()

    interface = LLama31Interface(tool_prompt_format)

    messages = interface.system_messages(**system_message_builtin_tools_only())
    messages += interface.user_message(content=[RawMediaItem(data=img), RawTextItem(text="What is this dog breed?")])
    messages += interface.assistant_response_messages(
        "Based on the description of the dog in the image, it appears to be a small breed dog, possibly a terrier mix",
        StopReason.end_of_turn,
    )
    messages += interface.user_message("Search the web for some food recommendations for the identified breed")
    return messages


def llama3_1_custom_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
    interface = LLama31Interface(tool_prompt_format)

    messages = interface.system_messages(**system_message_custom_tools_only())
    messages += interface.user_message(content="Use tools to get latest trending songs")
    return messages


def llama3_1_e2e_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
    tool_response = json.dumps(["great song1", "awesome song2", "cool song3"])
    interface = LLama31Interface(tool_prompt_format)

    messages = interface.system_messages(**system_message_custom_tools_only())
    messages += interface.user_message(content="Use tools to get latest trending songs")
    messages.append(
        RawMessage(
            role="assistant",
            content="",
            stop_reason=StopReason.end_of_message,
            tool_calls=[
                ToolCall(
                    call_id="call_id",
                    tool_name="trending_songs",
                    arguments={"n": "10", "genre": "latest"},
                )
            ],
        ),
    )
    messages.append(
        RawMessage(
            role="assistant",
            content=tool_response,
        )
    )
    return messages


def llama3_2_user_assistant_conversation():
    return UseCase(
        title="User and assistant conversation",
        description="Here is a regular multi-turn user assistant conversation and how it's formatted.",
        dialogs=[
            [
                RawMessage(role="system", content="You are a helpful assistant"),
                RawMessage(role="user", content="Who are you?"),
            ]
        ],
        notes="This format is unchanged from Llama3.1",
    )
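Editor's note: `UseCase.to_text()` above is the piece that turns each dialog into the markdown prompt-format docs. A hedged sketch of how a full report could be assembled from one model family's `usecases()` list; `build_generator` is a hypothetical factory, and the only assumption about the resulting object is that it exposes `chat_completion_raw`, `text_completion_raw`, a `tokenizer`, and a `formatter`, exactly as `dialogs_to_text` uses them.

```python
# Hedged sketch: render every use case for one model family into one markdown string.
from llama_stack.models.llama.llama3_2 import prompts as llama3_2_prompts


def render_prompt_format_doc(build_generator) -> str:
    generator = build_generator()  # hypothetical factory for a loaded model generator
    sections = []
    for case in llama3_2_prompts.usecases():
        # Plain strings are markdown headings/paragraphs; UseCase objects render themselves.
        sections.append(case if isinstance(case, str) else case.to_text(generator))
    return "\n".join(sections)
```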
1000  llama_stack/models/llama/sku_list.py  Normal file
File diff suppressed because it is too large.
@@ -7,7 +7,6 @@
 from typing import Any, List, Optional, Protocol
 from urllib.parse import urlparse

-from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field

 from llama_stack.apis.benchmarks import Benchmark
@@ -18,6 +17,7 @@ from llama_stack.apis.scoring_functions import ScoringFn
 from llama_stack.apis.shields import Shield
 from llama_stack.apis.tools import Tool
 from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.schema_utils import json_schema_type


 class ModelsProtocolPrivate(Protocol):
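Editor's note: every remaining hunk in this commit is the same mechanical substitution, so one summary is enough. Types that used to be imported from the `llama_models` package now come from `llama_stack` itself; the sketch below groups the old and new import paths seen in the hunks (the exact names each file pulls in vary).

```python
# Before (removed in the hunks below):
#   from llama_models.schema_utils import json_schema_type, webmethod
#   from llama_models.datatypes import CoreModelId, SamplingParams
#   from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolDefinition
#   from llama_models.sku_list import resolve_model

# After (added in the hunks below):
from llama_stack.schema_utils import json_schema_type, webmethod
from llama_stack.models.llama.datatypes import BuiltinTool, CoreModelId, SamplingParams, ToolCall, ToolDefinition
from llama_stack.models.llama.sku_list import resolve_model
```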
@@ -17,7 +17,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
 from urllib.parse import urlparse

 import httpx
-from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
 from pydantic import TypeAdapter

 from llama_stack.apis.agents import (
@@ -63,6 +62,7 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
+from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
 from llama_stack.providers.utils.kvstore import KVStore
 from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing

@@ -8,7 +8,6 @@ import tempfile
 from typing import AsyncIterator, List, Optional, Union

 import pytest
-from llama_models.llama3.api.datatypes import BuiltinTool

 from llama_stack.apis.agents import (
     AgentConfig,
@@ -41,6 +40,7 @@ from llama_stack.apis.tools import (
     ToolInvocationResult,
 )
 from llama_stack.apis.vector_io import QueryChunksResponse
+from llama_stack.models.llama.datatypes import BuiltinTool
 from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
     MEMORY_QUERY_TOOL,
 )

@@ -23,20 +23,13 @@ from fairscale.nn.model_parallel.initialize import (
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
-from llama_models.datatypes import (
-    GreedySamplingStrategy,
-    SamplingParams,
-    TopPSamplingStrategy,
-)
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
-from llama_models.llama3.api.datatypes import Model
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
-from llama_models.sku_list import resolve_model
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
 from pydantic import BaseModel

@@ -47,6 +40,13 @@ from llama_stack.apis.inference import (
     ResponseFormatType,
 )
 from llama_stack.distribution.utils.model_utils import model_local_dir
+from llama_stack.models.llama.datatypes import (
+    GreedySamplingStrategy,
+    Model,
+    SamplingParams,
+    TopPSamplingStrategy,
+)
+from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,

@@ -8,14 +8,6 @@ import asyncio
 import logging
 from typing import AsyncGenerator, List, Optional, Union

-from llama_models.llama3.api.datatypes import (
-    SamplingParams,
-    StopReason,
-    ToolDefinition,
-    ToolPromptFormat,
-)
-from llama_models.sku_list import resolve_model
-
 from llama_stack.apis.common.content_types import (
     TextDelta,
     ToolCallDelta,
@@ -41,6 +33,13 @@ from llama_stack.apis.inference import (
     ToolConfig,
 )
 from llama_stack.apis.models import Model, ModelType
+from llama_stack.models.llama.datatypes import (
+    SamplingParams,
+    StopReason,
+    ToolDefinition,
+    ToolPromptFormat,
+)
+from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,

@@ -10,10 +10,10 @@ from functools import partial
 from typing import Any, Generator

 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Model
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import resolve_model

+from llama_stack.models.llama.datatypes import Model
+from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,

@@ -14,14 +14,14 @@ from typing import Any, Dict, List, Optional
 import torch
 from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
-from llama_models.datatypes import CheckpointQuantizationFormat
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
-from llama_models.sku_list import resolve_model
 from torch import Tensor, nn
 from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

 from llama_stack.apis.inference import QuantizationType
+from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat
+from llama_stack.models.llama.sku_list import resolve_model

 from ..config import MetaReferenceQuantizedInferenceConfig

@@ -4,10 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field, field_validator

 from llama_stack.providers.utils.inference import supported_inference_models
+from llama_stack.schema_utils import json_schema_type


 @json_schema_type

@@ -11,7 +11,6 @@ from typing import AsyncGenerator, List, Optional

 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import resolve_model
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams as VLLMSamplingParams
@@ -35,6 +34,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model
+from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,

@@ -13,8 +13,6 @@
 from typing import Any, Callable, Dict

 import torch
-from llama_models.datatypes import Model
-from llama_models.sku_list import resolve_model
 from pydantic import BaseModel
 from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
 from torchtune.models.llama3 import llama3_tokenizer
@@ -24,6 +22,8 @@ from torchtune.models.llama3_2 import lora_llama3_2_3b
 from torchtune.modules.transforms import Transform

 from llama_stack.apis.post_training import DatasetFormat
+from llama_stack.models.llama.datatypes import Model
+from llama_stack.models.llama.sku_list import resolve_model


 class ModelConfig(BaseModel):

@@ -6,8 +6,6 @@
 from datetime import datetime
 from typing import Any, Dict, Optional

-from llama_models.schema_utils import webmethod
-
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
@@ -27,6 +25,7 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
 from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
     LoraFinetuningSingleDevice,
 )
+from llama_stack.schema_utils import webmethod


 class TorchtunePostTrainingImpl:

@@ -14,7 +14,6 @@ from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple

 import torch
-from llama_models.sku_list import resolve_model
 from torch import nn
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
@@ -46,6 +45,7 @@ from llama_stack.apis.post_training import (
 )
 from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
 from llama_stack.distribution.utils.model_utils import model_local_dir
+from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.inline.post_training.common.validator import (
     validate_input_dataset_schema,
 )

@@ -8,9 +8,6 @@ import re
 from string import Template
 from typing import Any, Dict, List, Optional

-from llama_models.datatypes import CoreModelId
-from llama_models.llama3.api.datatypes import Role
-
 from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
 from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
@@ -26,6 +23,7 @@ from llama_stack.apis.safety import (
 )
 from llama_stack.apis.shields import Shield
 from llama_stack.distribution.datatypes import Api
+from llama_stack.models.llama.datatypes import CoreModelId, Role
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,

@@ -6,13 +6,13 @@

 from typing import Any, Dict

-from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel

 from llama_stack.providers.utils.kvstore.config import (
     KVStoreConfig,
     SqliteKVStoreConfig,
 )
+from llama_stack.schema_utils import json_schema_type


 @json_schema_type

@@ -8,7 +8,6 @@ import json
 from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

 from botocore.client import BaseClient
-from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer

@@ -28,6 +27,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
 from llama_stack.providers.utils.inference.model_registry import (

@@ -7,9 +7,7 @@
 from typing import AsyncGenerator, List, Optional, Union

 from cerebras.cloud.sdk import AsyncCerebras
-from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import TopKSamplingStrategy
 from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.common.content_types import InterleavedContent
@@ -28,6 +26,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.models.llama.datatypes import CoreModelId, TopKSamplingStrategy
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,

@@ -7,9 +7,10 @@
 import os
 from typing import Any, Dict, Optional

-from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field, SecretStr

+from llama_stack.schema_utils import json_schema_type
+
 DEFAULT_BASE_URL = "https://api.cerebras.ai"


@@ -5,9 +5,10 @@
 # the root directory of this source tree.


-from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field

+from llama_stack.schema_utils import json_schema_type
+

 @json_schema_type
 class DatabricksImplConfig(BaseModel):

@@ -6,7 +6,6 @@

 from typing import AsyncGenerator, List, Optional

-from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI
@@ -25,6 +24,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,

@@ -6,9 +6,10 @@

 from typing import Any, Dict, Optional

-from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field, SecretStr

+from llama_stack.schema_utils import json_schema_type
+

 @json_schema_type
 class FireworksImplConfig(BaseModel):

@@ -7,7 +7,6 @@
 from typing import AsyncGenerator, List, Optional, Union

 from fireworks.client import Fireworks
-from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer

@@ -30,6 +29,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,

@@ -6,9 +6,10 @@

 from typing import Optional

-from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field

+from llama_stack.schema_utils import json_schema_type
+

 @json_schema_type
 class GroqConfig(BaseModel):

@@ -9,9 +9,6 @@ from typing import AsyncIterator, List, Optional, Union

 import groq
 from groq import Groq
-from llama_models.datatypes import SamplingParams
-from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
-from llama_models.sku_list import CoreModelId

 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -29,6 +26,8 @@ from llama_stack.apis.inference import (
     ToolConfig,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
+from llama_stack.models.llama.sku_list import CoreModelId
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,

@@ -24,7 +24,6 @@ from groq.types.chat.chat_completion_user_message_param import (
 )
 from groq.types.chat.completion_create_params import CompletionCreateParams
 from groq.types.shared.function_definition import FunctionDefinition
-from llama_models.llama3.api.datatypes import ToolParamDefinition

 from llama_stack.apis.common.content_types import (
     TextDelta,
@@ -44,6 +43,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.models.llama.datatypes import ToolParamDefinition
 from llama_stack.providers.utils.inference.openai_compat import (
     UnparseableToolCall,
     convert_tool_call,

@@ -7,9 +7,10 @@
 import os
 from typing import Any, Dict, Optional

-from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field, SecretStr

+from llama_stack.schema_utils import json_schema_type
+

 @json_schema_type
 class NVIDIAConfig(BaseModel):

@@ -7,9 +7,6 @@
 import warnings
 from typing import AsyncIterator, List, Optional, Union

-from llama_models.datatypes import SamplingParams
-from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
-from llama_models.sku_list import CoreModelId
 from openai import APIConnectionError, AsyncOpenAI

 from llama_stack.apis.inference import (
@@ -28,6 +25,7 @@ from llama_stack.apis.inference import (
     ToolChoice,
     ToolConfig,
 )
+from llama_stack.models.llama.datatypes import CoreModelId, SamplingParams, ToolDefinition, ToolPromptFormat
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,

@@ -8,17 +8,6 @@ import json
 import warnings
 from typing import Any, AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union

-from llama_models.datatypes import (
-    GreedySamplingStrategy,
-    TopKSamplingStrategy,
-    TopPSamplingStrategy,
-)
-from llama_models.llama3.api.datatypes import (
-    BuiltinTool,
-    StopReason,
-    ToolCall,
-    ToolDefinition,
-)
 from openai import AsyncStream
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
@@ -87,6 +76,15 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
+from llama_stack.models.llama.datatypes import (
+    BuiltinTool,
+    GreedySamplingStrategy,
+    StopReason,
+    ToolCall,
+    ToolDefinition,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
 )

@@ -8,7 +8,6 @@ import logging
 from typing import AsyncGenerator, List, Optional, Union

 import httpx
-from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from ollama import AsyncClient
@@ -34,6 +33,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model, ModelType
+from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,

@@ -6,9 +6,10 @@

 from typing import Optional

-from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field

+from llama_stack.schema_utils import json_schema_type
+

 @json_schema_type
 class RunpodImplConfig(BaseModel):

@@ -6,11 +6,11 @@
 from typing import AsyncGenerator

 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI

 from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.models.llama.datatypes import Message

 # from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

@@ -6,9 +6,10 @@

 from typing import Any, Dict, Optional

-from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field

+from llama_stack.schema_utils import json_schema_type
+

 @json_schema_type
 class SambaNovaImplConfig(BaseModel):

@@ -7,12 +7,6 @@
 import json
 from typing import AsyncGenerator

-from llama_models.datatypes import (
-    CoreModelId,
-    GreedySamplingStrategy,
-    TopKSamplingStrategy,
-    TopPSamplingStrategy,
-)
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI
@@ -23,6 +17,12 @@ from llama_stack.apis.common.content_types import (
     TextContentItem,
 )
 from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.models.llama.datatypes import (
+    CoreModelId,
+    GreedySamplingStrategy,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+)
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,

@@ -6,9 +6,10 @@

 from typing import Optional

-from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field, SecretStr

+from llama_stack.schema_utils import json_schema_type
+

 @json_schema_type
 class TGIImplConfig(BaseModel):

@@ -11,7 +11,6 @@ from typing import AsyncGenerator, List, Optional
 from huggingface_hub import AsyncInferenceClient, HfApi
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import all_registered_models

 from llama_stack.apis.common.content_types import InterleavedContent
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
|
@ -31,6 +30,7 @@ from llama_stack.apis.inference import (
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.models import Model
|
from llama_stack.apis.models import Model
|
||||||
|
from llama_stack.models.llama.sku_list import all_registered_models
|
||||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||||
from llama_stack.providers.utils.inference.model_registry import (
|
from llama_stack.providers.utils.inference.model_registry import (
|
||||||
ModelRegistryHelper,
|
ModelRegistryHelper,
|
||||||
|
|
|
@ -6,9 +6,10 @@
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
|
||||||
from pydantic import BaseModel, Field, SecretStr
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class TogetherImplConfig(BaseModel):
|
class TogetherImplConfig(BaseModel):
|
||||||
|
|
|
@ -6,7 +6,6 @@
|
||||||
|
|
||||||
from typing import AsyncGenerator, List, Optional, Union
|
from typing import AsyncGenerator, List, Optional, Union
|
||||||
|
|
||||||
from llama_models.datatypes import CoreModelId
|
|
||||||
from llama_models.llama3.api.chat_format import ChatFormat
|
from llama_models.llama3.api.chat_format import ChatFormat
|
||||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
from together import Together
|
from together import Together
|
||||||
|
@ -29,6 +28,7 @@ from llama_stack.apis.inference import (
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
|
from llama_stack.models.llama.datatypes import CoreModelId
|
||||||
from llama_stack.providers.utils.inference.model_registry import (
|
from llama_stack.providers.utils.inference.model_registry import (
|
||||||
ModelRegistryHelper,
|
ModelRegistryHelper,
|
||||||
build_model_alias,
|
build_model_alias,
|
||||||
|
|
|
@ -6,9 +6,10 @@
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class VLLMInferenceAdapterConfig(BaseModel):
|
class VLLMInferenceAdapterConfig(BaseModel):
|
||||||
|
|
|
@ -7,10 +7,9 @@ import json
|
||||||
import logging
|
import logging
|
||||||
from typing import AsyncGenerator, List, Optional, Union
|
from typing import AsyncGenerator, List, Optional, Union
|
||||||
|
|
||||||
from llama_models.llama3.api import StopReason, ToolCall
|
from llama_models.datatypes import StopReason, ToolCall
|
||||||
from llama_models.llama3.api.chat_format import ChatFormat
|
from llama_models.llama3.api.chat_format import ChatFormat
|
||||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
from llama_models.sku_list import all_registered_models
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
|
from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
|
||||||
|
@ -37,6 +36,7 @@ from llama_stack.apis.inference import (
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.models import Model, ModelType
|
from llama_stack.apis.models import Model, ModelType
|
||||||
|
from llama_stack.models.llama.sku_list import all_registered_models
|
||||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||||
from llama_stack.providers.utils.inference.model_registry import (
|
from llama_stack.providers.utils.inference.model_registry import (
|
||||||
ModelRegistryHelper,
|
ModelRegistryHelper,
|
||||||
|
|
|
@ -5,9 +5,8 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
|
||||||
|
|
||||||
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
|
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
|
||||||
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@ -7,7 +7,6 @@
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
|
@ -18,6 +17,7 @@ from llama_stack.apis.tools import (
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
|
from llama_stack.models.llama.datatypes import BuiltinTool
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||||
|
|
||||||
from .config import BraveSearchToolConfig
|
from .config import BraveSearchToolConfig
|
||||||
|
|
|
@ -4,9 +4,10 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class PGVectorVectorIOConfig(BaseModel):
|
class PGVectorVectorIOConfig(BaseModel):
|
||||||
|
|
|
@ -6,9 +6,10 @@
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class QdrantVectorIOConfig(BaseModel):
|
class QdrantVectorIOConfig(BaseModel):
|
||||||
|
|
|
@ -7,8 +7,6 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from llama_models.datatypes import SamplingParams, TopPSamplingStrategy
|
|
||||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
|
||||||
|
|
||||||
from llama_stack.apis.agents import (
|
from llama_stack.apis.agents import (
|
||||||
AgentConfig,
|
AgentConfig,
|
||||||
|
@ -25,6 +23,7 @@ from llama_stack.apis.agents import (
|
||||||
)
|
)
|
||||||
from llama_stack.apis.inference import CompletionMessage, UserMessage
|
from llama_stack.apis.inference import CompletionMessage, UserMessage
|
||||||
from llama_stack.apis.safety import ViolationLevel
|
from llama_stack.apis.safety import ViolationLevel
|
||||||
|
from llama_stack.models.llama.datatypes import BuiltinTool, SamplingParams, TopPSamplingStrategy
|
||||||
from llama_stack.providers.datatypes import Api
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
# How to run this test:
|
# How to run this test:
|
||||||
|
|
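The hunks above apply the same two substitutions throughout: `json_schema_type` is now imported from `llama_stack.schema_utils`, and the Llama model datatypes (`CoreModelId`, `BuiltinTool`, `StopReason`, the sampling strategies, and so on) move to `llama_stack.models.llama.datatypes`. A minimal sketch of a provider config written against the new import locations (the `ExampleImplConfig` class and its fields are hypothetical, shown only to illustrate the layout):

```python
from typing import Optional

from pydantic import BaseModel, Field

# json_schema_type now lives in llama_stack.schema_utils
# (previously imported from llama_models.schema_utils).
from llama_stack.schema_utils import json_schema_type


@json_schema_type
class ExampleImplConfig(BaseModel):
    # Hypothetical adapter config; the field names are illustrative only.
    url: Optional[str] = Field(default=None, description="Base URL of the inference endpoint")
    api_token: Optional[str] = Field(default=None, description="API token, if the endpoint requires one")
```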
Some files were not shown because too many files have changed in this diff.