chore: move all Llama Stack types from llama-models to llama-stack (#1098)

llama-models should have extremely minimal cruft. Its sole purpose
should be didactic -- show the simplest implementation of the llama
models and document the prompt formats, etc.

This PR is the complement to
https://github.com/meta-llama/llama-models/pull/279

## Test Plan

Ensure all `llama` CLI `model` sub-commands work:

```bash
llama model list
llama model download --model-id ...
llama model prompt-format -m ...
```

Ran tests:
```bash
cd tests/client-sdk
LLAMA_STACK_CONFIG=fireworks pytest -s -v inference/
LLAMA_STACK_CONFIG=fireworks pytest -s -v vector_io/
LLAMA_STACK_CONFIG=fireworks pytest -s -v agents/
```

Create a fresh venv `uv venv && source .venv/bin/activate` and run
`llama stack build --template fireworks --image-type venv` followed by
`llama stack run together --image-type venv` <-- the server runs

Also checked that the OpenAPI generator can run and there is no change
in the generated files as a result.

```bash
cd docs/openapi_generator
sh run_openapi_generator.sh
```
This commit is contained in:
Ashwin Bharambe 2025-02-14 09:10:59 -08:00 committed by GitHub
parent c0ee512980
commit 314ee09ae3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
138 changed files with 8491 additions and 465 deletions

View file

@ -30,6 +30,7 @@ repos:
rev: v0.9.4 rev: v0.9.4
hooks: hooks:
- id: ruff - id: ruff
exclude: ^llama_stack/strong_typing/.*$
- id: ruff-format - id: ruff-format
- repo: https://github.com/adamchainz/blacken-docs - repo: https://github.com/adamchainz/blacken-docs
@ -43,7 +44,13 @@ repos:
rev: 0.5.26 rev: 0.5.26
hooks: hooks:
- id: uv-export - id: uv-export
args: ["--frozen", "--no-hashes", "--no-emit-project"] args: [
"--frozen",
"--no-hashes",
"--no-emit-project",
"--output-file=requirements.txt"
]
files: ^pyproject\.toml$
- id: uv-sync - id: uv-sync
# - repo: https://github.com/pre-commit/mirrors-mypy # - repo: https://github.com/pre-commit/mirrors-mypy

View file

@ -16,18 +16,6 @@ from pathlib import Path
import fire import fire
import ruamel.yaml as yaml import ruamel.yaml as yaml
from llama_models import schema_utils
# We do some monkey-patching to ensure our definitions only use the minimal
# (json_schema_type, webmethod) definitions from the llama_models package. For
# generation though, we need the full definitions and implementations from the
# (json-strong-typing) package.
from .strong_typing.schema import json_schema_type, register_schema
schema_utils.json_schema_type = json_schema_type
schema_utils.register_schema = register_schema
from llama_stack.apis.version import LLAMA_STACK_API_VERSION # noqa: E402 from llama_stack.apis.version import LLAMA_STACK_API_VERSION # noqa: E402
from llama_stack.distribution.stack import LlamaStack # noqa: E402 from llama_stack.distribution.stack import LlamaStack # noqa: E402

View file

@ -10,9 +10,9 @@ import typing
from dataclasses import make_dataclass from dataclasses import make_dataclass
from typing import Any, Dict, Set, Union from typing import Any, Dict, Set, Union
from ..strong_typing.core import JsonType from llama_stack.strong_typing.core import JsonType
from ..strong_typing.docstring import Docstring, parse_type from llama_stack.strong_typing.docstring import Docstring, parse_type
from ..strong_typing.inspection import ( from llama_stack.strong_typing.inspection import (
is_generic_list, is_generic_list,
is_type_optional, is_type_optional,
is_type_union, is_type_union,
@ -20,15 +20,15 @@ from ..strong_typing.inspection import (
unwrap_optional_type, unwrap_optional_type,
unwrap_union_types, unwrap_union_types,
) )
from ..strong_typing.name import python_type_to_name from llama_stack.strong_typing.name import python_type_to_name
from ..strong_typing.schema import ( from llama_stack.strong_typing.schema import (
get_schema_identifier, get_schema_identifier,
JsonSchemaGenerator, JsonSchemaGenerator,
register_schema, register_schema,
Schema, Schema,
SchemaOptions, SchemaOptions,
) )
from ..strong_typing.serialization import json_dump_string, object_to_json from llama_stack.strong_typing.serialization import json_dump_string, object_to_json
from .operations import ( from .operations import (
EndpointOperation, EndpointOperation,

View file

@ -15,7 +15,7 @@ from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from termcolor import colored from termcolor import colored
from ..strong_typing.inspection import get_signature from llama_stack.strong_typing.inspection import get_signature
def split_prefix( def split_prefix(

View file

@ -9,7 +9,7 @@ import enum
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, ClassVar, Dict, List, Optional, Union from typing import Any, ClassVar, Dict, List, Optional, Union
from ..strong_typing.schema import JsonType, Schema, StrictJsonType from llama_stack.strong_typing.schema import JsonType, Schema, StrictJsonType
URL = str URL = str

View file

@ -9,7 +9,7 @@ import typing
from pathlib import Path from pathlib import Path
from typing import TextIO from typing import TextIO
from ..strong_typing.schema import object_to_json, StrictJsonType from llama_stack.strong_typing.schema import object_to_json, StrictJsonType
from .generator import Generator from .generator import Generator
from .options import Options from .options import Options

View file

@ -19,7 +19,6 @@ from typing import (
runtime_checkable, runtime_checkable,
) )
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
@ -38,6 +37,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import SafetyViolation from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef from llama_stack.apis.tools import ToolDef
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
class Attachment(BaseModel): class Attachment(BaseModel):

View file

@ -1,206 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.llama3.api.datatypes import ToolPromptFormat
from llama_models.llama3.api.tool_utils import ToolUtils
from termcolor import cprint
from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
from llama_stack.apis.common.content_types import ToolCallParseStatus
from llama_stack.apis.inference import ToolResponseMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
class LogEvent:
def __init__(
self,
role: Optional[str] = None,
content: str = "",
end: str = "\n",
color="white",
):
self.role = role
self.content = content
self.color = color
self.end = "\n" if end is None else end
def __str__(self):
if self.role is not None:
return f"{self.role}> {self.content}"
else:
return f"{self.content}"
def print(self, flush=True):
cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush)
EventType = AgentTurnResponseEventType
class EventLogger:
async def log(
self,
event_generator,
stream=True,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
):
previous_event_type = None
previous_step_type = None
async for chunk in event_generator:
if not hasattr(chunk, "event"):
# Need to check for custom tool first
# since it does not produce event but instead
# a Message
if isinstance(chunk, ToolResponseMessage):
yield (
chunk,
LogEvent(role="CustomTool", content=chunk.content, color="grey"),
)
continue
event = chunk.event
event_type = event.payload.event_type
if event_type in {
EventType.turn_start.value,
EventType.turn_complete.value,
}:
# Currently not logging any turn realted info
yield event, None
continue
step_type = event.payload.step_type
# handle safety
if step_type == StepType.shield_call and event_type == EventType.step_complete.value:
violation = event.payload.step_details.violation
if not violation:
yield (
event,
LogEvent(role=step_type, content="No Violation", color="magenta"),
)
else:
yield (
event,
LogEvent(
role=step_type,
content=f"{violation.metadata} {violation.user_message}",
color="red",
),
)
# handle inference
if step_type == StepType.inference:
if stream:
if event_type == EventType.step_start.value:
# TODO: Currently this event is never received
yield (
event,
LogEvent(role=step_type, content="", end="", color="yellow"),
)
elif event_type == EventType.step_progress.value:
# HACK: if previous was not step/event was not inference's step_progress
# this is the first time we are getting model inference response
# aka equivalent to step_start for inference. Hence,
# start with "Model>".
if (
previous_event_type != EventType.step_progress.value
and previous_step_type != StepType.inference
):
yield (
event,
LogEvent(role=step_type, content="", end="", color="yellow"),
)
delta = event.payload.delta
if delta.type == "tool_call":
if delta.parse_status == ToolCallParseStatus.succeeded:
yield (
event,
LogEvent(
role=None,
content=delta.tool_call,
end="",
color="cyan",
),
)
else:
yield (
event,
LogEvent(
role=None,
content=delta.text,
end="",
color="yellow",
),
)
else:
# step_complete
yield event, LogEvent(role=None, content="")
else:
# Not streaming
if event_type == EventType.step_complete.value:
response = event.payload.step_details.model_response
if response.tool_calls:
content = ToolUtils.encode_tool_call(response.tool_calls[0], tool_prompt_format)
else:
content = response.content
yield (
event,
LogEvent(
role=step_type,
content=content,
color="yellow",
),
)
# handle tool_execution
if (
step_type == StepType.tool_execution
and
# Only print tool calls and responses at the step_complete event
event_type == EventType.step_complete.value
):
details = event.payload.step_details
for t in details.tool_calls:
yield (
event,
LogEvent(
role=step_type,
content=f"Tool:{t.tool_name} Args:{t.arguments}",
color="green",
),
)
for r in details.tool_responses:
yield (
event,
LogEvent(
role=step_type,
content=f"Tool:{r.tool_name} Response:{r.content}",
color="green",
),
)
if step_type == StepType.memory_retrieval and event_type == EventType.step_complete.value:
details = event.payload.step_details
inserted_context = interleaved_content_as_str(details.inserted_context)
content = f"fetched {len(inserted_context)} bytes from {details.vector_db_ids}"
yield (
event,
LogEvent(
role=step_type,
content=content,
color="cyan",
),
)
previous_event_type = event_type
previous_step_type = step_type

View file

@ -6,7 +6,6 @@
from typing import List, Optional, Protocol, runtime_checkable from typing import List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
@ -21,6 +20,7 @@ from llama_stack.apis.inference import (
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type @json_schema_type

View file

@ -5,10 +5,10 @@
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, webmethod
class CommonBenchmarkFields(BaseModel): class CommonBenchmarkFields(BaseModel):

View file

@ -7,10 +7,11 @@
from enum import Enum from enum import Enum
from typing import Annotated, List, Literal, Optional, Union from typing import Annotated, List, Literal, Optional, Union
from llama_models.llama3.api.datatypes import ToolCall
from llama_models.schema_utils import json_schema_type, register_schema
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
from llama_stack.models.llama.datatypes import ToolCall
from llama_stack.schema_utils import json_schema_type, register_schema
@json_schema_type @json_schema_type
class URL(BaseModel): class URL(BaseModel):

View file

@ -7,10 +7,10 @@
from enum import Enum from enum import Enum
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.schema_utils import json_schema_type
@json_schema_type @json_schema_type

View file

@ -5,9 +5,10 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
@json_schema_type @json_schema_type
class Job(BaseModel): class Job(BaseModel):

View file

@ -7,9 +7,10 @@
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
@json_schema_type @json_schema_type
class PostTrainingMetric(BaseModel): class PostTrainingMetric(BaseModel):

View file

@ -6,10 +6,11 @@
from typing import Literal, Union from typing import Literal, Union
from llama_models.schema_utils import json_schema_type, register_schema
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_stack.schema_utils import json_schema_type, register_schema
@json_schema_type @json_schema_type
class StringType(BaseModel): class StringType(BaseModel):

View file

@ -6,10 +6,10 @@
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.apis.datasets import Dataset from llama_stack.apis.datasets import Dataset
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type @json_schema_type

View file

@ -6,12 +6,12 @@
from typing import Any, Dict, List, Literal, Optional, Protocol from typing import Any, Dict, List, Literal, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, webmethod
class CommonDatasetFields(BaseModel): class CommonDatasetFields(BaseModel):

View file

@ -6,7 +6,7 @@
from enum import Enum from enum import Enum
from llama_models.schema_utils import json_schema_type from llama_stack.schema_utils import json_schema_type
@json_schema_type @json_schema_type

View file

@ -6,7 +6,6 @@
from typing import Any, Dict, List, Literal, Optional, Protocol, Union from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated from typing_extensions import Annotated
@ -15,6 +14,7 @@ from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.inference import SamplingParams, SystemMessage from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type @json_schema_type

View file

@ -17,7 +17,13 @@ from typing import (
runtime_checkable, runtime_checkable,
) )
from llama_models.llama3.api.datatypes import ( from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.models.llama.datatypes import (
BuiltinTool, BuiltinTool,
SamplingParams, SamplingParams,
StopReason, StopReason,
@ -25,14 +31,8 @@ from llama_models.llama3.api.datatypes import (
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
class LogProbConfig(BaseModel): class LogProbConfig(BaseModel):

View file

@ -6,9 +6,10 @@
from typing import List, Protocol, runtime_checkable from typing import List, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type @json_schema_type
class ProviderInfo(BaseModel): class ProviderInfo(BaseModel):

View file

@ -7,11 +7,11 @@
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
class CommonModelFields(BaseModel): class CommonModelFields(BaseModel):

View file

@ -8,13 +8,13 @@ from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, Union from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.job_types import JobStatus from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.common.training_types import Checkpoint from llama_stack.apis.common.training_types import Checkpoint
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type @json_schema_type

View file

@ -7,12 +7,12 @@
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.apis.inference import Message from llama_stack.apis.inference import Message
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type @json_schema_type

View file

@ -6,10 +6,10 @@
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.schema_utils import json_schema_type, webmethod
# mapping of metric to value # mapping of metric to value
ScoringResultRow = Dict[str, Any] ScoringResultRow = Dict[str, Any]

View file

@ -16,12 +16,12 @@ from typing import (
runtime_checkable, runtime_checkable,
) )
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_stack.apis.common.type_system import ParamType from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
# Perhaps more structure can be imposed on these functions. Maybe they could be associated # Perhaps more structure can be imposed on these functions. Maybe they could be associated

View file

@ -6,11 +6,11 @@
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
class CommonShieldFields(BaseModel): class CommonShieldFields(BaseModel):

View file

@ -7,10 +7,10 @@
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union from typing import Any, Dict, List, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.apis.inference import Message from llama_stack.apis.inference import Message
from llama_stack.schema_utils import json_schema_type, webmethod
class FilteringFunction(Enum): class FilteringFunction(Enum):

View file

@ -17,11 +17,12 @@ from typing import (
runtime_checkable, runtime_checkable,
) )
from llama_models.llama3.api.datatypes import Primitive
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_stack.models.llama.datatypes import Primitive
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
# Add this constant near the top of the file, after the imports # Add this constant near the top of the file, after the imports
DEFAULT_TTL_DAYS = 7 DEFAULT_TTL_DAYS = 7

View file

@ -7,12 +7,12 @@
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union from typing import Any, Dict, List, Literal, Optional, Union
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated, Protocol, runtime_checkable from typing_extensions import Annotated, Protocol, runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type @json_schema_type

View file

@ -7,13 +7,13 @@
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Literal, Optional from typing import Any, Dict, List, Literal, Optional
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable from typing_extensions import Protocol, runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
from .rag_tool import RAGToolRuntime from .rag_tool import RAGToolRuntime

View file

@ -6,11 +6,11 @@
from typing import List, Literal, Optional, Protocol, runtime_checkable from typing import List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type @json_schema_type

View file

@ -10,12 +10,12 @@
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
class Chunk(BaseModel): class Chunk(BaseModel):

View file

@ -16,8 +16,6 @@ from pathlib import Path
from typing import Dict, List, Optional from typing import Dict, List, Optional
import httpx import httpx
from llama_models.datatypes import Model
from llama_models.sku_list import LlamaDownloadInfo
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from rich.console import Console from rich.console import Console
from rich.progress import ( from rich.progress import (
@ -31,6 +29,8 @@ from rich.progress import (
from termcolor import cprint from termcolor import cprint
from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
class Download(Subcommand): class Download(Subcommand):
@ -454,7 +454,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
# Handle comma-separated model IDs # Handle comma-separated model IDs
model_ids = [model_id.strip() for model_id in args.model_id.split(",")] model_ids = [model_id.strip() for model_id in args.model_id.split(",")]
from llama_models.sku_list import llama_meta_net_info, resolve_model from llama_stack.models.llama.sku_list import llama_meta_net_info, resolve_model
from .model.safety_models import ( from .model.safety_models import (
prompt_guard_download_info, prompt_guard_download_info,

View file

@ -7,11 +7,11 @@
import argparse import argparse
import json import json
from llama_models.sku_list import resolve_model
from termcolor import colored from termcolor import colored
from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table from llama_stack.cli.table import print_table
from llama_stack.models.llama.sku_list import resolve_model
class ModelDescribe(Subcommand): class ModelDescribe(Subcommand):

View file

@ -6,10 +6,9 @@
import argparse import argparse
from llama_models.sku_list import all_registered_models
from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table from llama_stack.cli.table import print_table
from llama_stack.models.llama.sku_list import all_registered_models
class ModelList(Subcommand): class ModelList(Subcommand):

View file

@ -8,9 +8,8 @@ import argparse
import textwrap import textwrap
from io import StringIO from io import StringIO
from llama_models.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
class ModelPromptFormat(Subcommand): class ModelPromptFormat(Subcommand):

View file

@ -6,11 +6,11 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from llama_models.datatypes import CheckpointQuantizationFormat
from llama_models.llama3.api.datatypes import SamplingParams
from llama_models.sku_list import LlamaDownloadInfo
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat, SamplingParams
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
class PromptGuardModel(BaseModel): class PromptGuardModel(BaseModel):
"""Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed.""" """Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""

View file

@ -186,33 +186,3 @@ def extract_async_iterator_type(type_hint):
inner_args = get_args(arg) inner_args = get_args(arg)
return inner_args[0] return inner_args[0]
return None return None
async def example(model: str = None):
from llama_stack.apis.inference import Inference, UserMessage # noqa: F403
from llama_stack.apis.inference.event_logger import EventLogger
client_class = create_api_client_class(Inference)
client = client_class("http://localhost:5003")
if not model:
model = "Llama3.2-3B-Instruct"
message = UserMessage(content="hello world, write me a 2 sentence poem about the moon")
cprint(f"User>{message.content}", "green")
stream = True
iterator = await client.chat_completion(
model=model,
messages=[message],
stream=stream,
)
async for log in EventLogger().log(iterator):
log.print()
if __name__ == "__main__":
import asyncio
asyncio.run(example())

View file

@ -0,0 +1,277 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from enum import Enum
from typing import Any, Dict, Literal, Optional, Union
# import all for backwards compatibility
from llama_models.datatypes import * # noqa: F403
from pydantic import BaseModel, ConfigDict, Field, field_validator
from typing_extensions import Annotated
from llama_stack.schema_utils import json_schema_type, register_schema
register_schema(ToolCall)
@json_schema_type
class ToolParamDefinition(BaseModel):
param_type: str
description: Optional[str] = None
required: Optional[bool] = True
default: Optional[Any] = None
@json_schema_type
class ToolDefinition(BaseModel):
tool_name: Union[BuiltinTool, str]
description: Optional[str] = None
parameters: Optional[Dict[str, ToolParamDefinition]] = None
@field_validator("tool_name", mode="before")
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinTool(v)
except ValueError:
return v
return v
@json_schema_type
class GreedySamplingStrategy(BaseModel):
type: Literal["greedy"] = "greedy"
@json_schema_type
class TopPSamplingStrategy(BaseModel):
type: Literal["top_p"] = "top_p"
temperature: Optional[float] = Field(..., gt=0.0)
top_p: Optional[float] = 0.95
@json_schema_type
class TopKSamplingStrategy(BaseModel):
type: Literal["top_k"] = "top_k"
top_k: int = Field(..., ge=1)
SamplingStrategy = register_schema(
Annotated[
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
Field(discriminator="type"),
],
name="SamplingStrategy",
)
@json_schema_type
class SamplingParams(BaseModel):
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
max_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
class CheckpointQuantizationFormat(Enum):
# default format
bf16 = "bf16"
# used for enabling fp8_rowwise inference, some weights are bf16
fp8_mixed = "fp8-mixed"
int8 = "int8"
int4 = "int4"
class ModelFamily(Enum):
llama2 = "llama2"
llama3 = "llama3"
llama3_1 = "llama3_1"
llama3_2 = "llama3_2"
llama3_3 = "llama3_3"
safety = "safety"
class CoreModelId(Enum):
"""Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
# Llama 2 family
llama2_7b = "Llama-2-7b"
llama2_13b = "Llama-2-13b"
llama2_70b = "Llama-2-70b"
llama2_7b_chat = "Llama-2-7b-chat"
llama2_13b_chat = "Llama-2-13b-chat"
llama2_70b_chat = "Llama-2-70b-chat"
# Llama 3 family
llama3_8b = "Llama-3-8B"
llama3_70b = "Llama-3-70B"
llama3_8b_instruct = "Llama-3-8B-Instruct"
llama3_70b_instruct = "Llama-3-70B-Instruct"
# Llama 3.1 family
llama3_1_8b = "Llama3.1-8B"
llama3_1_70b = "Llama3.1-70B"
llama3_1_405b = "Llama3.1-405B"
llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
llama3_1_405b_instruct = "Llama3.1-405B-Instruct"
# Llama 3.2 family
llama3_2_1b = "Llama3.2-1B"
llama3_2_3b = "Llama3.2-3B"
llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
llama3_2_11b_vision = "Llama3.2-11B-Vision"
llama3_2_90b_vision = "Llama3.2-90B-Vision"
llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"
# Llama 3.3 family
llama3_3_70b_instruct = "Llama3.3-70B-Instruct"
# Safety models
llama_guard_3_8b = "Llama-Guard-3-8B"
llama_guard_2_8b = "Llama-Guard-2-8B"
llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
llama_guard_3_1b = "Llama-Guard-3-1B"
def is_multimodal(model_id) -> bool:
if model_id in [
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return True
else:
return False
def model_family(model_id) -> ModelFamily:
if model_id in [
CoreModelId.llama2_7b,
CoreModelId.llama2_13b,
CoreModelId.llama2_70b,
CoreModelId.llama2_7b_chat,
CoreModelId.llama2_13b_chat,
CoreModelId.llama2_70b_chat,
]:
return ModelFamily.llama2
elif model_id in [
CoreModelId.llama3_8b,
CoreModelId.llama3_70b,
CoreModelId.llama3_8b_instruct,
CoreModelId.llama3_70b_instruct,
]:
return ModelFamily.llama3
elif model_id in [
CoreModelId.llama3_1_8b,
CoreModelId.llama3_1_70b,
CoreModelId.llama3_1_405b,
CoreModelId.llama3_1_8b_instruct,
CoreModelId.llama3_1_70b_instruct,
CoreModelId.llama3_1_405b_instruct,
]:
return ModelFamily.llama3_1
elif model_id in [
CoreModelId.llama3_2_1b,
CoreModelId.llama3_2_3b,
CoreModelId.llama3_2_1b_instruct,
CoreModelId.llama3_2_3b_instruct,
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return ModelFamily.llama3_2
elif model_id in [
CoreModelId.llama3_3_70b_instruct,
]:
return ModelFamily.llama3_3
elif model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_2_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return ModelFamily.safety
else:
raise ValueError(f"Unknown model family for {model_id}")
class Model(BaseModel):
core_model_id: CoreModelId
description: str
huggingface_repo: Optional[str] = None
recommended_sampling_params: Optional[SamplingParams] = None
arch_args: Dict[str, Any]
variant: str = ""
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
pth_file_count: int
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
# silence pydantic until we remove the `model_` fields
model_config = ConfigDict(protected_namespaces=())
@property
def model_family(self) -> ModelFamily:
return model_family(self.core_model_id)
# The SKU is uniquely identified by (model_id, variant) combo
def descriptor(self, shorten_default_variant: bool = True) -> str:
if not self.variant:
return self.core_model_id.value
return f"{self.core_model_id.value}:{self.variant}"
@property
def is_instruct_model(self) -> bool:
return "instruct" in self.id.name
# Featured models are shown in the non-exhaustive model list
@property
def is_featured(self) -> bool:
return self.model_family in [
ModelFamily.llama3_1,
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.safety,
]
@property
def max_seq_length(self) -> int:
if self.model_family == ModelFamily.llama2:
return 4096
elif self.core_model_id == CoreModelId.llama_guard_2_8b:
return 4096
elif self.model_family == ModelFamily.llama3:
return 8192
elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
return 131072
elif self.model_family == ModelFamily.llama3_2:
if self.quantization_format == CheckpointQuantizationFormat.int4:
return 8192
return 131072
elif self.core_model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return 131072
else:
raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

View file

@ -0,0 +1,257 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from pathlib import Path
from typing import List, Optional
from llama_models.datatypes import (
BuiltinTool,
RawMessage,
StopReason,
ToolCall,
ToolPromptFormat,
)
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from termcolor import colored
from llama_stack.models.llama.datatypes import ToolDefinition
from . import template_data
from .prompt_templates import (
BuiltinToolGenerator,
FunctionTagCustomToolGenerator,
JsonCustomToolGenerator,
SystemDefaultGenerator,
ToolResponseGenerator,
)
THIS_DIR = Path(__file__).parent
class Template:
def __init__(
self,
role,
template_name,
data_provider=None,
notes=None,
):
self.role = role
self.template_name = template_name
self.data_provider = data_provider or ""
self._notes = notes or ""
@property
def notes(self):
default = "↵ represents newline"
notes = default
if self._notes:
notes += "\n"
notes += self._notes
return notes
TEMPLATES = [
Template(
"user",
"user-default",
"user_default",
),
Template(
"user",
"user-images",
"user_images",
),
Template("user", "user-interleaved-images", "user_interleaved_images"),
Template(
"assistant",
"assistant-builtin-tool-call",
"assistant_builtin_tool_call",
"Notice <|python_tag|>",
),
Template(
"assistant",
"assistant-custom-tool-call",
"assistant_custom_tool_call",
"Notice <function=...> format",
),
Template(
"assistant",
"assistant-default",
"assistant_default",
),
Template(
"system",
"system-builtin-and-custom-tools",
"system_message_builtin_and_custom_tools",
),
Template(
"system",
"system-builtin-tools-only",
"system_message_builtin_tools_only",
),
Template(
"system",
"system-custom-tools-only",
"system_message_custom_tools_only",
),
Template(
"system",
"system-default",
"system_default",
),
Template(
"tool",
"tool-success",
"tool_success",
"Note ipython header and [stdout]",
),
Template(
"tool",
"tool-failure",
"tool_failure",
"Note ipython header and [stderr]",
),
]
class LLama31Interface:
def __init__(self, tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json):
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
self.tool_prompt_format = tool_prompt_format
def get_tokens(self, messages: List[RawMessage]) -> List[int]:
model_input = self.formatter.encode_dialog_prompt(
messages,
self.tool_prompt_format,
)
return model_input.tokens
def tool_response_messages(self, *args, **kwargs):
template = ToolResponseGenerator().gen(*args, **kwargs)
return [
RawMessage(
role="tool",
content=template.render(),
)
]
def system_messages(
self,
builtin_tools: List[BuiltinTool],
custom_tools: List[ToolDefinition],
instruction: Optional[str] = None,
) -> List[RawMessage]:
messages = []
default_gen = SystemDefaultGenerator()
default_template = default_gen.gen()
sys_content = ""
tool_template = None
if builtin_tools or custom_tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(builtin_tools + custom_tools)
sys_content += tool_template.render()
sys_content += "\n"
sys_content += default_template.render()
if instruction:
sys_content += "\n\n"
sys_content += instruction
sys_content += "\n"
messages.append(RawMessage(role="system", content=sys_content))
if custom_tools:
if self.tool_prompt_format == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
elif self.tool_prompt_format == ToolPromptFormat.function_tag:
tool_gen = FunctionTagCustomToolGenerator()
else:
raise ValueError(f"Non supported ToolPromptFormat {self.tool_prompt_format}")
custom_template = tool_gen.gen(custom_tools)
messages.append(RawMessage(role="user", content=custom_template.render()))
return messages
def assistant_response_messages(
self,
content: str,
stop_reason: StopReason,
tool_call: Optional[ToolCall] = None,
) -> List[RawMessage]:
tool_calls = []
if tool_call:
tool_calls.append(tool_call)
return [
RawMessage(
role="assistant",
content=content,
tool_calls=tool_calls,
stop_reason=stop_reason,
)
]
def user_message(self, content: str) -> List[RawMessage]:
return [RawMessage(role="user", content=content)]
def display_message_as_tokens(self, message: RawMessage) -> None:
"""Util to print tokenized string to shell"""
tokens = self.formatter.encode_message(message, self.tool_prompt_format)
on_colors = [
"on_red",
"on_green",
"on_yellow",
"on_blue",
"on_magenta",
"on_cyan",
]
for i, t in enumerate(tokens):
on_col = on_colors[i % len(on_colors)]
print(colored(self.tokenizer.decode([t]), "white", on_col), end="")
print("\n", end="")
def list_jinja_templates() -> List[Template]:
return TEMPLATES
def render_jinja_template(name: str, tool_prompt_format: ToolPromptFormat):
by_name = {t.template_name: t for t in TEMPLATES}
if name not in by_name:
raise ValueError(f"No template found for `{name}`")
template = by_name[name]
interface = LLama31Interface(tool_prompt_format)
data_func = getattr(template_data, template.data_provider)
if template.role == "system":
messages = interface.system_messages(**data_func())
elif template.role == "tool":
messages = interface.tool_response_messages(**data_func())
elif template.role == "assistant":
messages = interface.assistant_response_messages(**data_func())
elif template.role == "user":
messages = interface.user_message(**data_func())
tokens = interface.get_tokens(messages)
special_tokens = list(interface.tokenizer.special_tokens.values())
tokens = [(interface.tokenizer.decode([t]), t in special_tokens) for t in tokens]
return template, tokens

Binary file not shown.

After

Width:  |  Height:  |  Size: 438 KiB

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from .base import PromptTemplate, PromptTemplateGeneratorBase # noqa: F401
from .system_prompts import ( # noqa: F401
BuiltinToolGenerator,
FunctionTagCustomToolGenerator,
JsonCustomToolGenerator,
PythonListCustomToolGenerator,
SystemDefaultGenerator,
)
from .tool_response import ToolResponseGenerator # noqa: F401

View file

@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from dataclasses import dataclass
from typing import Any, Dict, List
from jinja2 import Template
@dataclass
class PromptTemplate:
template: str
data: Dict[str, Any]
def render(self):
template = Template(self.template)
return template.render(self.data)
class PromptTemplateGeneratorBase:
"""
Base class for prompt template generators.
"""
def gen(self, *args, **kwargs) -> PromptTemplate:
raise NotImplementedError()
def data_examples(self) -> List[Any]:
raise NotImplementedError()

View file

@ -0,0 +1,311 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import textwrap
from datetime import datetime
from typing import Any, List, Optional
from llama_models.datatypes import (
BuiltinTool,
)
from llama_stack.models.llama.datatypes import (
ToolDefinition,
ToolParamDefinition,
)
from .base import PromptTemplate, PromptTemplateGeneratorBase
class SystemDefaultGenerator(PromptTemplateGeneratorBase):
def gen(self, *args, **kwargs) -> PromptTemplate:
template_str = textwrap.dedent(
"""
Cutting Knowledge Date: December 2023
Today Date: {{ today }}
"""
)
return PromptTemplate(
template_str.lstrip("\n"),
{"today": datetime.now().strftime("%d %B %Y")},
)
def data_examples(self) -> List[Any]:
return [None]
class BuiltinToolGenerator(PromptTemplateGeneratorBase):
def _tool_breakdown(self, tools: List[ToolDefinition]):
builtin_tools, custom_tools = [], []
for dfn in tools:
if isinstance(dfn.tool_name, BuiltinTool):
builtin_tools.append(dfn)
else:
custom_tools.append(dfn)
return builtin_tools, custom_tools
def gen(self, tools: List[ToolDefinition]) -> PromptTemplate:
builtin_tools, custom_tools = self._tool_breakdown(tools)
template_str = textwrap.dedent(
"""
{% if builtin_tools or custom_tools -%}
Environment: ipython
{% endif -%}
{% set builtin_tools = builtin_tools | reject('equalto', 'code_interpreter') | list -%}
{% if builtin_tools -%}
Tools: {{ builtin_tools | join(", ") | trim -}}
{% endif %}
"""
)
return PromptTemplate(
template_str.lstrip("\n"),
{
"builtin_tools": [t.tool_name.value for t in builtin_tools],
"custom_tools": custom_tools,
},
)
def data_examples(self) -> List[List[ToolDefinition]]:
return [
# builtin tools
[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(tool_name=BuiltinTool.brave_search),
ToolDefinition(tool_name=BuiltinTool.wolfram_alpha),
],
# only code interpretor
[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
],
]
class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent(
"""
Answer the user's question by making use of the following functions if needed.
If none of the function can be used, please say so.
Here is a list of functions in JSON format:
{% for t in custom_tools -%}
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
{%- set tname = t.tool_name -%}
{%- set tdesc = t.description -%}
{%- set tparams = t.parameters -%}
{%- set required_params = [] -%}
{%- for name, param in tparams.items() if param.required == true -%}
{%- set _ = required_params.append(name) -%}
{%- endfor -%}
{
"type": "function",
"function": {
"name": "{{tname}}",
"description": "{{tdesc}}",
"parameters": {
"type": "object",
"properties": [
{%- for name, param in tparams.items() %}
{
"{{name}}": {
"type": "object",
"description": "{{param.description}}"
}
}{% if not loop.last %},{% endif %}
{%- endfor %}
],
"required": {{ required_params | tojson }}
}
}
}
{% endfor %}
Return function calls in JSON format.
"""
)
return PromptTemplate(
template_str.lstrip("\n"),
{"custom_tools": [t.model_dump() for t in custom_tools]},
)
def data_examples(self) -> List[List[ToolDefinition]]:
return [
[
ToolDefinition(
tool_name="trending_songs",
description="Returns the trending songs on a Music site",
parameters={
"n": ToolParamDefinition(
param_type="int",
description="The number of songs to return",
required=True,
),
"genre": ToolParamDefinition(
param_type="str",
description="The genre of the songs to return",
required=False,
),
},
),
]
]
class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent(
"""
You have access to the following functions:
{% for t in custom_tools %}
{#- manually setting up JSON because jinja sorts keys in unexpected ways -#}
{%- set tname = t.tool_name -%}
{%- set tdesc = t.description -%}
{%- set modified_params = t.parameters.copy() -%}
{%- for key, value in modified_params.items() -%}
{%- if 'default' in value -%}
{%- set _ = value.pop('default', None) -%}
{%- endif -%}
{%- endfor -%}
{%- set tparams = modified_params | tojson -%}
Use the function '{{ tname }}' to '{{ tdesc }}':
{"name": "{{tname}}", "description": "{{tdesc}}", "parameters": {{tparams}}}
{% endfor -%}
Think very carefully before calling functions.
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
<function=example_function_name>{"example_name": "example_value"}</function>
Reminder:
- If looking for real time information use relevant functions before falling back to brave_search
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
"""
)
return PromptTemplate(
template_str.lstrip("\n"),
{"custom_tools": [t.model_dump() for t in custom_tools]},
)
def data_examples(self) -> List[List[ToolDefinition]]:
return [
[
ToolDefinition(
tool_name="trending_songs",
description="Returns the trending songs on a Music site",
parameters={
"n": ToolParamDefinition(
param_type="int",
description="The number of songs to return",
required=True,
),
"genre": ToolParamDefinition(
param_type="str",
description="The genre of the songs to return",
required=False,
),
},
),
]
]
class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
DEFAULT_PROMPT = textwrap.dedent(
"""
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.
{{ function_description }}
""".strip("\n")
)
def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate:
system_prompt = system_prompt or self.DEFAULT_PROMPT
return PromptTemplate(
system_prompt,
{"function_description": self._gen_function_description(custom_tools)},
)
def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent(
"""
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[
{% for t in tools -%}
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
{%- set tname = t.tool_name -%}
{%- set tdesc = t.description -%}
{%- set tparams = t.parameters -%}
{%- set required_params = [] -%}
{%- for name, param in tparams.items() if param.required == true -%}
{%- set _ = required_params.append(name) -%}
{%- endfor -%}
{
"name": "{{tname}}",
"description": "{{tdesc}}",
"parameters": {
"type": "dict",
"required": {{ required_params | tojson }},
"properties": {
{%- for name, param in tparams.items() %}
"{{name}}": {
"type": "{{param.param_type}}",
"description": "{{param.description}}"{% if param.default %},
"default": "{{param.default}}"{% endif %}
}{% if not loop.last %},{% endif %}
{%- endfor %}
}
}
}{% if not loop.last %},
{% endif -%}
{%- endfor %}
]
"""
)
return PromptTemplate(
template_str.strip("\n"),
{"tools": [t.model_dump() for t in custom_tools]},
).render()
def data_examples(self) -> List[List[ToolDefinition]]:
return [
[
ToolDefinition(
tool_name="get_weather",
description="Get weather info for places",
parameters={
"city": ToolParamDefinition(
param_type="string",
description="The name of the city to get the weather for",
required=True,
),
"metric": ToolParamDefinition(
param_type="string",
description="The metric for weather. Options are: celsius, fahrenheit",
required=False,
default="celsius",
),
},
),
]
]

View file

@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import textwrap
from typing import Optional
from .base import PromptTemplate, PromptTemplateGeneratorBase
class ToolResponseGenerator(PromptTemplateGeneratorBase):
def gen(
self,
status: str,
stdout: Optional[str] = None,
stderr: Optional[str] = None,
):
assert status in [
"success",
"failure",
], f"status must be 'success' or 'failure'; Got: {status}"
template_str = textwrap.dedent(
"""
{% if status == "success" %}completed{% else %}failed{% endif %}
{%- if stdout %}
[stdout]{{ stdout }}[/stdout]
{%- endif -%}
{%- if stderr %}
[stderr]{{ stderr }}[/stderr]
{%- endif -%}
"""
)
return PromptTemplate(
template_str.lstrip("\n"),
{
"status": status,
"stdout": stdout,
"stderr": stderr,
},
)
def data_examples(self):
return [
# success
{
"status": "success",
"stdout": '{"results":["something something"]}',
},
# failure
{
"status": "failure",
"stderr": "brave_search encounter an error: could not communicate with api.brave.com",
},
]

View file

@ -0,0 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from llama_models.datatypes import (
BuiltinTool,
StopReason,
ToolCall,
)
from .prompt_templates import (
BuiltinToolGenerator,
JsonCustomToolGenerator,
ToolResponseGenerator,
)
INSTRUCTION = "You are a helpful assistant."
def system_message_builtin_tools_only():
return {
"builtin_tools": BuiltinToolGenerator().data_examples()[0],
"custom_tools": [],
"instruction": INSTRUCTION,
}
def system_message_builtin_code_only():
return {
"builtin_tools": BuiltinToolGenerator().data_examples()[1],
"custom_tools": [],
"instruction": "",
}
def system_message_custom_tools_only():
return {
"builtin_tools": [],
"custom_tools": JsonCustomToolGenerator().data_examples()[0],
"instruction": INSTRUCTION,
}
def system_message_builtin_and_custom_tools():
return {
"builtin_tools": BuiltinToolGenerator().data_examples()[0],
"custom_tools": JsonCustomToolGenerator().data_examples()[0],
"instruction": INSTRUCTION,
}
def system_default():
return {
"builtin_tools": [],
"custom_tools": [],
"instruction": INSTRUCTION,
}
def tool_success():
return ToolResponseGenerator().data_examples()[0]
def tool_failure():
return ToolResponseGenerator().data_examples()[1]
def assistant_builtin_tool_call():
return {
"content": "",
"tool_call": ToolCall(
call_id="uuid",
tool_name=BuiltinTool.brave_search,
arguments={
"query": "Who won NBA in 2024?",
},
),
"stop_reason": StopReason.end_of_message,
}
def assistant_custom_tool_call():
return {
"content": "",
"tool_call": ToolCall(
call_id="uuid",
tool_name="trending_songs",
arguments={"country": "US", "n": 10},
),
"stop_reason": StopReason.end_of_turn,
}
def assistant_default():
return {
"content": "Hi, I am a helpful assistant. What can I help you with today?",
"tool_call": None,
"stop_reason": StopReason.end_of_turn,
}
def user_default():
return {"content": "Please tell me how to plan a trip to New York"}
def user_images():
return {"content": "<|image|><|image|>What do these images depict?"}
def user_interleaved_images():
return {"content": "<|image|>Describe the image in one sentence.<|image|>Write a haiku about these images"}

View file

@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import textwrap
import unittest
from datetime import datetime
from .prompt_templates import (
BuiltinToolGenerator,
FunctionTagCustomToolGenerator,
JsonCustomToolGenerator,
PythonListCustomToolGenerator,
SystemDefaultGenerator,
)
class PromptTemplateTests(unittest.TestCase):
def check_generator_output(self, generator, expected_text):
example = generator.data_examples()[0]
pt = generator.gen(example)
text = pt.render()
# print(text) # debugging
assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
def test_system_default(self):
generator = SystemDefaultGenerator()
today = datetime.now().strftime("%d %B %Y")
expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
self.check_generator_output(generator, expected_text)
def test_system_builtin_only(self):
generator = BuiltinToolGenerator()
expected_text = textwrap.dedent(
"""
Environment: ipython
Tools: brave_search, wolfram_alpha
"""
)
self.check_generator_output(generator, expected_text.strip("\n"))
def test_system_custom_only(self):
self.maxDiff = None
generator = JsonCustomToolGenerator()
expected_text = textwrap.dedent(
"""
Answer the user's question by making use of the following functions if needed.
If none of the function can be used, please say so.
Here is a list of functions in JSON format:
{
"type": "function",
"function": {
"name": "trending_songs",
"description": "Returns the trending songs on a Music site",
"parameters": {
"type": "object",
"properties": [
{
"n": {
"type": "object",
"description": "The number of songs to return"
}
},
{
"genre": {
"type": "object",
"description": "The genre of the songs to return"
}
}
],
"required": ["n"]
}
}
}
Return function calls in JSON format.
"""
)
self.check_generator_output(generator, expected_text.strip("\n"))
def test_system_custom_function_tag(self):
self.maxDiff = None
generator = FunctionTagCustomToolGenerator()
expected_text = textwrap.dedent(
"""
You have access to the following functions:
Use the function 'trending_songs' to 'Returns the trending songs on a Music site':
{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}
Think very carefully before calling functions.
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
<function=example_function_name>{"example_name": "example_value"}</function>
Reminder:
- If looking for real time information use relevant functions before falling back to brave_search
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
"""
)
self.check_generator_output(generator, expected_text.strip("\n"))
def test_llama_3_2_system_zero_shot(self):
generator = PythonListCustomToolGenerator()
expected_text = textwrap.dedent(
"""
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[
{
"name": "get_weather",
"description": "Get weather info for places",
"parameters": {
"type": "dict",
"required": ["city"],
"properties": {
"city": {
"type": "string",
"description": "The name of the city to get the weather for"
},
"metric": {
"type": "string",
"description": "The metric for weather. Options are: celsius, fahrenheit",
"default": "celsius"
}
}
}
}
]
"""
)
self.check_generator_output(generator, expected_text.strip("\n"))
def test_llama_3_2_provided_system_prompt(self):
generator = PythonListCustomToolGenerator()
expected_text = textwrap.dedent(
"""
Overriding message.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[
{
"name": "get_weather",
"description": "Get weather info for places",
"parameters": {
"type": "dict",
"required": ["city"],
"properties": {
"city": {
"type": "string",
"description": "The name of the city to get the weather for"
},
"metric": {
"type": "string",
"description": "The metric for weather. Options are: celsius, fahrenheit",
"default": "celsius"
}
}
}
}
]"""
)
user_system_prompt = textwrap.dedent(
"""
Overriding message.
{{ function_description }}
"""
)
example = generator.data_examples()[0]
pt = generator.gen(example, user_system_prompt)
text = pt.render()
assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"

View file

@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

View file

@ -0,0 +1,259 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import textwrap
from typing import List
from llama_models.datatypes import (
BuiltinTool,
RawMessage,
StopReason,
ToolCall,
ToolPromptFormat,
)
from ..prompt_format import (
# llama3_1_e2e_tool_call_dialog,
TextCompletionContent,
UseCase,
llama3_1_builtin_tool_call_dialog,
llama3_1_custom_tool_call_dialog,
)
def wolfram_alpha_response():
return textwrap.dedent(
"""
{
"queryresult": {
"success": true,
"inputstring": "100th decimal of pi",
"pods": [
{
"title": "Input interpretation",
"subpods": [
{
"title": "",
"plaintext": "100th digit | \u03c0"
}
]
},
{
"title": "Nearby digits",
"subpods": [
{
"title": "",
"plaintext": "...86208998628034825342117067982148086513282306647093..."
}
]
},
{
"title": "Result",
"primary": true,
"subpods": [
{
"title": "",
"plaintext": "7"
}
]
}
]
}
}
"""
)
def usecases() -> List[UseCase | str]:
return [
textwrap.dedent(
"""
# Llama 3.1 - Prompt Formats
## Tokens
Here is a list of special tokens that are supported by Llama 3.1:
- `<|begin_of_text|>`: Specifies the start of the prompt
- `<|end_of_text|>`: Model will cease to generate more tokens. This token is generated only by the base models.
- `<|finetune_right_pad_id|>`: This token is used for padding text sequences to the same length in a batch.
- `<|start_header_id|>` and `<|end_header_id|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user, assistant and tool]
- `<|eom_id|>`: End of message. A message represents a possible stopping point for execution where the model can inform the executor that a tool call needs to be made. This is used for multi-step interactions between the model and any available tools. This token is emitted by the model when the Environment: ipython instruction is used in the system prompt, or if the model calls for a built-in tool.
- `<|eot_id|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios:
- at the end of a direct interaction between the model and the user
- at the end of multiple interactions between the model and any available tools
This token signals to the executor that the model has finished generating a response.
- `<|python_tag|>`: Is a special tag used in the model's response to signify a tool call.
"""
),
textwrap.dedent(
"""
There are 4 different roles that are supported by Llama 3.1
- `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively.
- `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model.
- `tool`: A new role introduced in Llama 3.1. This role is used to mark messages with the output of a tool call when sent back to the model from the executor. (The actual token used by the model for this role is "ipython".)
- `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `tool` and `user` prompts.
"""
),
UseCase(
title="Llama 3.1 Base Model",
description="Text completion for Llama 3.1 base model uses this format.",
dialogs=[TextCompletionContent(content="Color of sky is blue but sometimes can also be")],
notes="Note start special tag",
),
"## Llama 3.1 Instruct Model",
UseCase(
title="User and assistant conversation",
description="Here is a regular multi-turn user assistant conversation and how its formatted.",
dialogs=[
[
RawMessage(role="system", content="You are a helpful assistant"),
RawMessage(
role="user",
content="Answer who are you in the form of jeopardy?",
),
]
],
notes="",
),
"## Tool Calling Formats",
textwrap.dedent(
"""
The three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt:
- Brave Search: Tool call to perform web searches.
- Wolfram Alpha: Tool call to perform complex mathematical calculations.
- Code Interpreter: Enables the model to output python code.
"""
),
UseCase(
title="Builtin Tool Calling",
description=textwrap.dedent(
"""
Here is an example of a conversation using brave search
"""
),
dialogs=[llama3_1_builtin_tool_call_dialog()],
notes=textwrap.dedent(
"""
- Just including Environment: ipython turns on code interpreter; therefore, you don't need to specify code interpretation on the Tools: line. The model can generate python code which is interpreted by the executor, with the result provided back to the model.
- The message body of the assistant response starts with a special tag <|python_tag|>
- As alluded to above, in such an environment, the model can generate <|eom_id|> instead of just the standard <|eot_id|> . The latter indicates the turn is finished, while the former indicates continued multi-step reasoning. That is, the model is expecting a continuation message with the output of the tool call.
- The model tool call response is of the form `tool.call(query="...")` wher tool is `brave_search` or `wolfram_alpha`
"""
),
),
UseCase(
title="Builtin Code Interpreter",
description="Here is an actual example of model responding with code",
dialogs=[
[
RawMessage(role="system", content="Environment: ipython"),
RawMessage(
role="user",
content="Write code to check if number is prime, use that to see if the number 7 is prime",
),
],
],
notes=textwrap.dedent(
"""
- Model starts with <|python_tag|> and continues writing python code that it needs to be executed
- No explicit mention of code_interpreter in system prompt. `Environment: ipython` implicitly enables it.
"""
),
),
UseCase(
title="Built-in tools full interaction",
description="Here is a full interaction with the built-in tools including the tool response and the final assistant response.",
dialogs=[
[
RawMessage(
role="system",
content="Environment: ipython\nTools: brave_search, wolfram_alpha\n",
),
RawMessage(role="user", content="What is the 100th decimal of pi?"),
RawMessage(
role="assistant",
content="",
stop_reason=StopReason.end_of_message,
tool_calls=[
ToolCall(
call_id="tool_call_id",
tool_name=BuiltinTool.wolfram_alpha,
arguments={"query": "100th decimal of pi"},
)
],
),
RawMessage(
role="tool",
content=wolfram_alpha_response(),
),
],
],
notes=textwrap.dedent(
"""
- Note the `<|python_tag|>` in the assistant response.
- Role is `tool` for the wolfram alpha response that is passed back to the model.
- Final message from assistant has <|eot_id|> tag.
"""
),
),
"## Zero shot tool calling",
UseCase(
title="JSON based tool calling",
description=textwrap.dedent(
"""
Llama models can now output custom tool calls from a single message to allow easier tool calling.
The following prompts provide an example of how custom tools can be called from the output of the model.
It's important to note that the model itself does not execute the calls; it provides structured output to facilitate calling by an executor.
"""
),
dialogs=[llama3_1_custom_tool_call_dialog()],
notes=textwrap.dedent(
"""
- JSON format for providing tools needs name, description and parameters
- Model responds with `<|python_tag|>` and `<|eom_id|>` as `Environment: ipython` was in the system prompt
- Instructions for tools added as a user message
- Only single tool calls are supported as of now
"""
),
),
# FIXME: This is not working yet as expected
# UseCase(
# title="E2E tool call example",
# description=textwrap.dedent(
# """
# Here is an example showing the whole multi-step turn by taking custom tool outputs and passing back to the model.
# """
# ),
# dialogs=[
# llama3_1_e2e_tool_call_dialog(
# tool_prompt_format=ToolPromptFormat.function_tag
# )
# ],
# notes="",
# ),
"## Example of a user defined tool calling",
UseCase(
title="`<function>` based tool calling",
description=textwrap.dedent(
"""
Here is an example of how you could also write custom instructions for model to do zero shot tool calling.
In this example, we define a custom tool calling format using the `<function>` tag.
"""
),
dialogs=[llama3_1_custom_tool_call_dialog(ToolPromptFormat.function_tag)],
notes=textwrap.dedent(
"""
- In this case, model does NOT respond with `<|python_tag|>` and ends with `<|eot_id|>`
- Instructions for tools added as a user message
"""
),
),
]

View file

@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

View file

@ -0,0 +1,235 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import json
import textwrap
from llama_models.datatypes import (
RawMessage,
StopReason,
ToolCall,
ToolPromptFormat,
)
from ..prompt_format import (
TextCompletionContent,
UseCase,
llama3_1_builtin_code_interpreter_dialog,
)
def user_tool_call():
content = textwrap.dedent(
"""
Questions: Can you retrieve the details for the user with the ID 7890, who has black as their special request?
Here is a list of functions in JSON format that you can invoke:
[
{
"name": "get_user_info",
"description": "Retrieve details for a specific user by their unique identifier. Note that the provided function is in Python 3 syntax.",
"parameters": {
"type": "dict",
"required": [
"user_id"
],
"properties": {
"user_id": {
"type": "integer",
"description": "The unique identifier of the user. It is used to fetch the specific user details from the database."
},
"special": {
"type": "string",
"description": "Any special information or parameters that need to be considered while fetching user details.",
"default": "none"
}
}
}
}
]
Should you decide to return the function call(s),Put it in the format of [func1(params_name=params_value, params_name2=params_value2...), func2(params)]
NO other text MUST be included.
"""
)
return content.strip()
def system_tool_call():
content = textwrap.dedent(
"""
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[
{
"name": "get_weather",
"description": "Get weather info for places",
"parameters": {
"type": "dict",
"required": [
"city"
],
"properties": {
"city": {
"type": "string",
"description": "The name of the city to get the weather for"
},
"metric": {
"type": "string",
"description": "The metric for weather. Options are: celsius, fahrenheit",
"default": "celsius"
}
}
}
}
]
"""
)
return content.strip()
def usecases():
return [
UseCase(
title="User and assistant conversation",
description="Here is a regular multi-turn user assistant conversation and how its formatted.",
dialogs=[
[
RawMessage(role="system", content="You are a helpful assistant"),
RawMessage(role="user", content="Who are you?"),
]
],
notes="This format is unchanged from Llama3.1",
),
UseCase(
title="Zero shot function calling",
description=textwrap.dedent(
"""
For Llama3.2 1B and 3B instruct models, we are introducing a new format for zero shot function calling.
This new format is designed to be more flexible and powerful than the previous format.
All available functions can be provided in the system message. A key difference is in the format of how the assistant responds with function calls.
It is pythonic in the form of `[func1(params_name=params_value, params_name2=params_value2...), func2(params)]` instead of the `json` or `<function>` tag that were defined in Llama3.1.
Here is an example for the same,
"""
),
dialogs=[
# Zero shot tool calls as system message
[
RawMessage(role="system", content=system_tool_call()),
RawMessage(role="user", content="What is the weather in SF and Seattle?"),
],
],
notes=textwrap.dedent(
"""
- The output supports multiple tool calls natively
- JSON format for defining the functions in the system prompt is similar to Llama3.1
"""
),
),
UseCase(
title="Zero shot function calling with user message",
description=textwrap.dedent(
"""
While the default is to provide all function calls in a system message, in Llama3.2 text models you can also provide information for all the available tools in a user message.
"""
),
dialogs=[
# Zero shot tool call as user message
[
RawMessage(role="user", content=user_tool_call()),
],
],
notes=textwrap.dedent(
"""
- The tool call format for the model is the same whether your function calls are provided in the system or user message.
- While builtin tool calls end with a <|eom_id|>, notice the <|eot_id|> for zero shot tool calls.
"""
),
),
UseCase(
title="Code Interpreter",
description=textwrap.dedent(
"""
Code Interpreter continues to work in 3.2 text models similar to Llama 3.1 model family.
Here is an example,
"""
),
dialogs=[llama3_1_builtin_code_interpreter_dialog()],
notes=textwrap.dedent(
"""
- Note `Environment: ipython` in the system prompt.
- Note that the response starts with `<|python_tag|>` and ends with `<|eom_id|>`
"""
),
),
UseCase(
title="Zero shot function calling E2E format",
description=textwrap.dedent(
"""
Here is an example of the e2e cycle of tool calls with the model in a muti-step way.
"""
),
dialogs=[
[
RawMessage(role="system", content=system_tool_call()),
RawMessage(role="user", content="What is the weather in SF?"),
RawMessage(
role="assistant",
content="",
stop_reason=StopReason.end_of_turn,
tool_calls=[
ToolCall(
call_id="cc",
tool_name="get_weather",
arguments={
"city": "San Francisco",
"metric": "celsius",
},
)
],
),
RawMessage(
role="tool",
content=json.dumps("25 C"),
),
],
],
notes=textwrap.dedent(
"""
- The output of the function call is provided back to the model as a tool response ( in json format ).
- Notice `<|start_header_id|>ipython<|end_header_id|>` as the header message preceding the tool response.
- The model finally summarizes the information from the tool response and returns the result to the user.
"""
),
tool_prompt_format=ToolPromptFormat.python_list,
),
UseCase(
title="Prompt format for base models",
description=textwrap.dedent(
"""
For base models (Llama3.2-1B and Llama3.2-3B), the prompt format for a simple completion is as follows
"""
),
dialogs=[
TextCompletionContent(content="The color of the sky is blue but sometimes it can also be"),
],
notes="Same as Llama3.1",
),
]

View file

@ -0,0 +1,133 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import textwrap
from pathlib import Path
from llama_models.datatypes import (
RawMediaItem,
RawMessage,
RawTextItem,
)
from ..prompt_format import (
TextCompletionContent,
UseCase,
llama3_1_builtin_tool_call_dialog,
# llama3_1_builtin_tool_call_with_image_dialog,
llama3_2_user_assistant_conversation,
)
def usecases():
this_dir = Path(__file__).parent.parent.resolve()
with open(this_dir / "scripts/resources/dog.jpg", "rb") as f:
img = f.read()
return [
llama3_2_user_assistant_conversation(),
UseCase(
title="User and assistant conversation with Images",
description="This example shows how to pass and image to the model as part of the messages.",
dialogs=[
[
RawMessage(
role="user",
content=[
RawMediaItem(data=img),
RawTextItem(text="Describe this image in two sentences"),
],
)
],
],
notes=textwrap.dedent(
"""
- The `<|image|>` tag is used to indicate presence of the image
- The model isn't an early fusion model so doesn't actually translate an image into several tokens. Instead the cross-attention layers take input "on the side" from a vision encoder
![Image](mm-model.png)
- Its important to postion the <|image|> tag appropriately in the prompt. Image will only attend to the subsequent text tokens
- The <|image|> tag is part of the user message body, implying that it should only come after the header `<|start_header_id|>{role}<|end_header_id|>` in the message body
- We recommend using a single image in one prompt
"""
),
),
UseCase(
title="Builtin and Zero Shot Tool Calling",
description=textwrap.dedent(
"""
Llama3.2 vision models follow the same tool calling format as Llama3.1 models when inputs are text only.
Use `Environment: ipython` to enable tools.
Add `Tools: {{tool_name1}},{{tool_name2}}` for each of the builtin tools.
The same builtin tools as Llama3.1 are available,
- code_interpreter (for executing python code)
- brave_search (to search the web)
- wolfram_alpha (for querying wolfram alpha for mathematical questions)
""",
),
dialogs=[llama3_1_builtin_tool_call_dialog()],
notes=textwrap.dedent(
"""
- Note the `<|python_tag|>` before `brave_search` function call.
- The `<|eom_id|>` tag is used to indicate the end of the message.
- Similar to Llama3.1, code_interpreter is not explicitly mentioned but is enabled via `Environment: ipython`.
- Tool Calling does NOT work with images in the prompt as of now.
"""
),
),
# UseCase(
# title="Tool Calling for vision models",
# description=textwrap.dedent(
# """
# While Llama3.2 vision models follow the same tool calling format as Llama3.1 models when inputs are text only,
# they are not able to do tool calling when prompt contains image inputs (along with text).
# The recommended way would be to separate out the image understanding from the tool calling in successive prompts.
# Here is an example of how that could be done,
# """,
# ),
# dialogs=[llama3_1_builtin_tool_call_with_image_dialog()],
# notes=textwrap.dedent(
# """
# - Instead of a single prompt (image understanding + tool call), we split into two prompts to achieve the same result.
# """
# ),
# ),
UseCase(
title="Prompt format for base models",
description=textwrap.dedent(
"""
For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), the prompt format for a simple completion is as follows
"""
),
dialogs=[
TextCompletionContent(content="The color of the sky is blue but sometimes it can also be"),
],
notes="- Same as Llama3.1",
),
UseCase(
title="Prompt format for base models with Image",
description=textwrap.dedent(
"""
For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), here is an example of how the text completion format looks with an image,
"""
),
dialogs=[
TextCompletionContent(
content=[
RawMediaItem(data=img),
RawTextItem(text="If I had to write a haiku for this one"),
]
),
],
notes="- Note the placement of the special tags <|begin_of_text|> and <|image|>",
),
]

View file

@ -0,0 +1,258 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import textwrap
from typing import List
from llama_models.datatypes import (
BuiltinTool,
RawMessage,
StopReason,
ToolCall,
ToolPromptFormat,
)
from ..prompt_format import (
# llama3_1_e2e_tool_call_dialog,
TextCompletionContent,
UseCase,
llama3_1_builtin_tool_call_dialog,
llama3_1_custom_tool_call_dialog,
)
def wolfram_alpha_response():
return textwrap.dedent(
"""
{
"queryresult": {
"success": true,
"inputstring": "100th decimal of pi",
"pods": [
{
"title": "Input interpretation",
"subpods": [
{
"title": "",
"plaintext": "100th digit | \u03c0"
}
]
},
{
"title": "Nearby digits",
"subpods": [
{
"title": "",
"plaintext": "...86208998628034825342117067982148086513282306647093..."
}
]
},
{
"title": "Result",
"primary": true,
"subpods": [
{
"title": "",
"plaintext": "7"
}
]
}
]
}
}
"""
)
def usecases() -> List[UseCase | str]:
return [
textwrap.dedent(
"""
# Llama 3.1 - Prompt Formats
## Tokens
Here is a list of special tokens that are supported by Llama 3.1:
- `<|begin_of_text|>`: Specifies the start of the prompt
- `<|end_of_text|>`: Model will cease to generate more tokens. This token is generated only by the base models.
- `<|finetune_right_pad_id|>`: This token is used for padding text sequences to the same length in a batch.
- `<|start_header_id|>` and `<|end_header_id|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user, assistant and tool]
- `<|eom_id|>`: End of message. A message represents a possible stopping point for execution where the model can inform the executor that a tool call needs to be made. This is used for multi-step interactions between the model and any available tools. This token is emitted by the model when the Environment: ipython instruction is used in the system prompt, or if the model calls for a built-in tool.
- `<|eot_id|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios:
- at the end of a direct interaction between the model and the user
- at the end of multiple interactions between the model and any available tools
This token signals to the executor that the model has finished generating a response.
- `<|python_tag|>`: Is a special tag used in the model's response to signify a tool call.
"""
),
textwrap.dedent(
"""
There are 4 different roles that are supported by Llama 3.1
- `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively.
- `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model.
- `tool`: A new role introduced in Llama 3.1. This role is used to mark messages with the output of a tool call when sent back to the model from the executor. (The actual token used by the model for this role is "ipython".)
- `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `tool` and `user` prompts.
"""
),
UseCase(
title="Llama 3.1 Base Model",
description="Text completion for Llama 3.1 base model uses this format.",
dialogs=[TextCompletionContent(content="Color of sky is blue but sometimes can also be")],
notes="Note start special tag",
),
"## Llama 3.1 Instruct Model",
UseCase(
title="User and assistant conversation",
description="Here is a regular multi-turn user assistant conversation and how its formatted.",
dialogs=[
[
RawMessage(role="system", content="You are a helpful assistant"),
RawMessage(
role="user",
content="Answer who are you in the form of jeopardy?",
),
]
],
notes="",
),
"## Tool Calling Formats",
textwrap.dedent(
"""
The three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt:
- Brave Search: Tool call to perform web searches.
- Wolfram Alpha: Tool call to perform complex mathematical calculations.
- Code Interpreter: Enables the model to output python code.
"""
),
UseCase(
title="Builtin Tool Calling",
description=textwrap.dedent(
"""
Here is an example of a conversation using brave search
"""
),
dialogs=[llama3_1_builtin_tool_call_dialog()],
notes=textwrap.dedent(
"""
- Just including Environment: ipython turns on code interpreter; therefore, you don't need to specify code interpretation on the Tools: line. The model can generate python code which is interpreted by the executor, with the result provided back to the model.
- The message body of the assistant response starts with a special tag <|python_tag|>
- As alluded to above, in such an environment, the model can generate <|eom_id|> instead of just the standard <|eot_id|> . The latter indicates the turn is finished, while the former indicates continued multi-step reasoning. That is, the model is expecting a continuation message with the output of the tool call.
- The model tool call response is of the form `tool.call(query="...")` wher tool is `brave_search` or `wolfram_alpha`
"""
),
),
UseCase(
title="Builtin Code Interpreter",
description="Here is an actual example of model responding with code",
dialogs=[
[
RawMessage(role="system", content="Environment: ipython"),
RawMessage(
role="user",
content="Write code to check if number is prime, use that to see if the number 7 is prime",
),
],
],
notes=textwrap.dedent(
"""
- Model starts with <|python_tag|> and continues writing python code that it needs to be executed
- No explicit mention of code_interpreter in system prompt. `Environment: ipython` implicitly enables it.
"""
),
),
UseCase(
title="Built-in tools full interaction",
description="Here is a full interaction with the built-in tools including the tool response and the final assistant response.",
dialogs=[
[
RawMessage(
role="system",
content="Environment: ipython\nTools: brave_search, wolfram_alpha\n",
),
RawMessage(role="user", content="What is the 100th decimal of pi?"),
RawMessage(
content="",
stop_reason=StopReason.end_of_message,
tool_calls=[
ToolCall(
call_id="tool_call_id",
tool_name=BuiltinTool.wolfram_alpha,
arguments={"query": "100th decimal of pi"},
)
],
),
RawMessage(
role="tool",
content=wolfram_alpha_response(),
),
],
],
notes=textwrap.dedent(
"""
- Note the `<|python_tag|>` in the assistant response.
- Role is `tool` for the wolfram alpha response that is passed back to the model.
- Final message from assistant has <|eot_id|> tag.
"""
),
),
"## Zero shot tool calling",
UseCase(
title="JSON based tool calling",
description=textwrap.dedent(
"""
Llama models can now output custom tool calls from a single message to allow easier tool calling.
The following prompts provide an example of how custom tools can be called from the output of the model.
It's important to note that the model itself does not execute the calls; it provides structured output to facilitate calling by an executor.
"""
),
dialogs=[llama3_1_custom_tool_call_dialog()],
notes=textwrap.dedent(
"""
- JSON format for providing tools needs name, description and parameters
- Model responds with `<|python_tag|>` and `<|eom_id|>` as `Environment: ipython` was in the system prompt
- Instructions for tools added as a user message
- Only single tool calls are supported as of now
"""
),
),
# FIXME: This is not working yet as expected
# UseCase(
# title="E2E tool call example",
# description=textwrap.dedent(
# """
# Here is an example showing the whole multi-step turn by taking custom tool outputs and passing back to the model.
# """
# ),
# dialogs=[
# llama3_1_e2e_tool_call_dialog(
# tool_prompt_format=ToolPromptFormat.function_tag
# )
# ],
# notes="",
# ),
"## Example of a user defined tool calling",
UseCase(
title="`<function>` based tool calling",
description=textwrap.dedent(
"""
Here is an example of how you could also write custom instructions for model to do zero shot tool calling.
In this example, we define a custom tool calling format using the `<function>` tag.
"""
),
dialogs=[llama3_1_custom_tool_call_dialog(ToolPromptFormat.function_tag)],
notes=textwrap.dedent(
"""
- In this case, model does NOT respond with `<|python_tag|>` and ends with `<|eot_id|>`
- Instructions for tools added as a user message
"""
),
),
]

View file

@ -0,0 +1,204 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import json
import textwrap
from pathlib import Path
from typing import List
from llama_models.datatypes import (
RawContent,
RawMediaItem,
RawMessage,
RawTextItem,
StopReason,
ToolCall,
ToolPromptFormat,
)
from pydantic import BaseModel, Field
from .llama3.interface import LLama31Interface
from .llama3.template_data import (
system_message_builtin_code_only,
system_message_builtin_tools_only,
system_message_custom_tools_only,
)
class TextCompletionContent(BaseModel):
content: RawContent = ""
class UseCase(BaseModel):
title: str = ""
description: str = ""
dialogs: List[List[RawMessage] | TextCompletionContent | str] = Field(default_factory=list)
notes: str = ""
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json
def md_format(self):
section = textwrap.dedent(
"""
## {title}
{description}
{dialogs_text}
{notes}
"""
)
return section.lstrip()
def dialogs_to_text(self, generator) -> str:
def _code_block(text):
return f"```\n{text}\n```"
text = ""
for dialog in self.dialogs:
if isinstance(dialog, str):
text += dialog
text += "\n\n"
continue
elif isinstance(dialog, TextCompletionContent):
input_tokens, output_tokens = generator.text_completion_raw(
dialog.content,
max_gen_len=64,
temperature=0.1,
top_p=0.95,
)
else:
input_tokens, output_tokens = generator.chat_completion_raw(
dialog,
max_gen_len=512,
temperature=0.0,
top_p=0.95,
tool_prompt_format=self.tool_prompt_format,
)
text += "##### Input Prompt Format\n"
# FIXME: This is added to undo the hack in chat_formatter where
# vision tokens are replaced with 128256.
input_tokens = [generator.formatter.vision_token if t == 128256 else t for t in input_tokens]
text += _code_block(generator.tokenizer.decode(input_tokens))
# TODO: Figure out if "↵" needs to be added for newlines or end or some indication
text += "\n\n"
text += "##### Model Response Format\n"
text += _code_block(generator.tokenizer.decode(output_tokens))
text += "\n\n"
return text
def to_text(self, generator):
section = self.md_format()
dialogs_text = self.dialogs_to_text(generator)
notes = f"##### Notes\n{self.notes}" if self.notes else ""
section = section.format(
title=self.title,
description=self.description,
dialogs_text=dialogs_text,
notes=notes,
)
return section
def llama3_1_builtin_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
interface = LLama31Interface(tool_prompt_format)
messages = interface.system_messages(**system_message_builtin_tools_only())
messages += interface.user_message(content="Search the web for the latest price of 1oz gold?")
return messages
def llama3_1_builtin_code_interpreter_dialog(tool_prompt_format=ToolPromptFormat.json):
interface = LLama31Interface(tool_prompt_format)
messages = interface.system_messages(**system_message_builtin_code_only())
messages += interface.user_message(
content="Write code to check if number is prime. Use it to verify if number 7 is prime"
)
return messages
def llama3_1_builtin_tool_call_with_image_dialog(
tool_prompt_format=ToolPromptFormat.json,
):
this_dir = Path(__file__).parent
with open(this_dir / "llama3/dog.jpg", "rb") as f:
img = f.read()
interface = LLama31Interface(tool_prompt_format)
messages = interface.system_messages(**system_message_builtin_tools_only())
messages += interface.user_message(content=[RawMediaItem(data=img), RawTextItem(text="What is this dog breed?")])
messages += interface.assistant_response_messages(
"Based on the description of the dog in the image, it appears to be a small breed dog, possibly a terrier mix",
StopReason.end_of_turn,
)
messages += interface.user_message("Search the web for some food recommendations for the indentified breed")
return messages
def llama3_1_custom_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
interface = LLama31Interface(tool_prompt_format)
messages = interface.system_messages(**system_message_custom_tools_only())
messages += interface.user_message(content="Use tools to get latest trending songs")
return messages
def llama3_1_e2e_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
tool_response = json.dumps(["great song1", "awesome song2", "cool song3"])
interface = LLama31Interface(tool_prompt_format)
messages = interface.system_messages(**system_message_custom_tools_only())
messages += interface.user_message(content="Use tools to get latest trending songs")
messages.append(
RawMessage(
role="assistant",
content="",
stop_reason=StopReason.end_of_message,
tool_calls=[
ToolCall(
call_id="call_id",
tool_name="trending_songs",
arguments={"n": "10", "genre": "latest"},
)
],
),
)
messages.append(
RawMessage(
role="assistant",
content=tool_response,
)
)
return messages
def llama3_2_user_assistant_conversation():
return UseCase(
title="User and assistant conversation",
description="Here is a regular multi-turn user assistant conversation and how its formatted.",
dialogs=[
[
RawMessage(role="system", content="You are a helpful assistant"),
RawMessage(role="user", content="Who are you?"),
]
],
notes="This format is unchanged from Llama3.1",
)

File diff suppressed because it is too large Load diff

View file

@ -7,7 +7,6 @@
from typing import Any, List, Optional, Protocol from typing import Any, List, Optional, Protocol
from urllib.parse import urlparse from urllib.parse import urlparse
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.apis.benchmarks import Benchmark from llama_stack.apis.benchmarks import Benchmark
@ -18,6 +17,7 @@ from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import Tool from llama_stack.apis.tools import Tool
from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.schema_utils import json_schema_type
class ModelsProtocolPrivate(Protocol): class ModelsProtocolPrivate(Protocol):

View file

@ -17,7 +17,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
import httpx import httpx
from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
from pydantic import TypeAdapter from pydantic import TypeAdapter
from llama_stack.apis.agents import ( from llama_stack.apis.agents import (
@ -63,6 +62,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
from llama_stack.providers.utils.telemetry import tracing from llama_stack.providers.utils.telemetry import tracing

View file

@ -8,7 +8,6 @@ import tempfile
from typing import AsyncIterator, List, Optional, Union from typing import AsyncIterator, List, Optional, Union
import pytest import pytest
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.agents import ( from llama_stack.apis.agents import (
AgentConfig, AgentConfig,
@ -41,6 +40,7 @@ from llama_stack.apis.tools import (
ToolInvocationResult, ToolInvocationResult,
) )
from llama_stack.apis.vector_io import QueryChunksResponse from llama_stack.apis.vector_io import QueryChunksResponse
from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.inline.agents.meta_reference.agent_instance import ( from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
MEMORY_QUERY_TOOL, MEMORY_QUERY_TOOL,
) )

View file

@ -23,20 +23,13 @@ from fairscale.nn.model_parallel.initialize import (
initialize_model_parallel, initialize_model_parallel,
model_parallel_is_initialized, model_parallel_is_initialized,
) )
from llama_models.datatypes import (
GreedySamplingStrategy,
SamplingParams,
TopPSamplingStrategy,
)
from llama_models.llama3.api.args import ModelArgs from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, LLMInput from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
from llama_models.llama3.api.datatypes import Model
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer from llama_models.llama3.reference_impl.model import Transformer
from llama_models.llama3.reference_impl.multimodal.model import ( from llama_models.llama3.reference_impl.multimodal.model import (
CrossAttentionTransformer, CrossAttentionTransformer,
) )
from llama_models.sku_list import resolve_model
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from pydantic import BaseModel from pydantic import BaseModel
@ -47,6 +40,13 @@ from llama_stack.apis.inference import (
ResponseFormatType, ResponseFormatType,
) )
from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.models.llama.datatypes import (
GreedySamplingStrategy,
Model,
SamplingParams,
TopPSamplingStrategy,
)
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent, ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent, CompletionRequestWithRawContent,

View file

@ -8,14 +8,6 @@ import asyncio
import logging import logging
from typing import AsyncGenerator, List, Optional, Union from typing import AsyncGenerator, List, Optional, Union
from llama_models.llama3.api.datatypes import (
SamplingParams,
StopReason,
ToolDefinition,
ToolPromptFormat,
)
from llama_models.sku_list import resolve_model
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
TextDelta, TextDelta,
ToolCallDelta, ToolCallDelta,
@ -41,6 +33,13 @@ from llama_stack.apis.inference import (
ToolConfig, ToolConfig,
) )
from llama_stack.apis.models import Model, ModelType from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import (
SamplingParams,
StopReason,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import ( from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin, SentenceTransformerEmbeddingMixin,

View file

@ -10,10 +10,10 @@ from functools import partial
from typing import Any, Generator from typing import Any, Generator
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Model
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent, ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent, CompletionRequestWithRawContent,

View file

@ -14,14 +14,14 @@ from typing import Any, Dict, List, Optional
import torch import torch
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.datatypes import CheckpointQuantizationFormat
from llama_models.llama3.api.args import ModelArgs from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from llama_models.sku_list import resolve_model
from torch import Tensor, nn from torch import Tensor, nn
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from llama_stack.apis.inference import QuantizationType from llama_stack.apis.inference import QuantizationType
from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat
from llama_stack.models.llama.sku_list import resolve_model
from ..config import MetaReferenceQuantizedInferenceConfig from ..config import MetaReferenceQuantizedInferenceConfig

View file

@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from llama_stack.providers.utils.inference import supported_inference_models from llama_stack.providers.utils.inference import supported_inference_models
from llama_stack.schema_utils import json_schema_type
@json_schema_type @json_schema_type

View file

@ -11,7 +11,6 @@ from typing import AsyncGenerator, List, Optional
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams as VLLMSamplingParams from vllm.sampling_params import SamplingParams as VLLMSamplingParams
@ -35,6 +34,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.models import Model from llama_stack.apis.models import Model
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice, OpenAICompatCompletionChoice,

View file

@ -13,8 +13,6 @@
from typing import Any, Callable, Dict from typing import Any, Callable, Dict
import torch import torch
from llama_models.datatypes import Model
from llama_models.sku_list import resolve_model
from pydantic import BaseModel from pydantic import BaseModel
from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
from torchtune.models.llama3 import llama3_tokenizer from torchtune.models.llama3 import llama3_tokenizer
@ -24,6 +22,8 @@ from torchtune.models.llama3_2 import lora_llama3_2_3b
from torchtune.modules.transforms import Transform from torchtune.modules.transforms import Transform
from llama_stack.apis.post_training import DatasetFormat from llama_stack.apis.post_training import DatasetFormat
from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.sku_list import resolve_model
class ModelConfig(BaseModel): class ModelConfig(BaseModel):

View file

@ -6,8 +6,6 @@
from datetime import datetime from datetime import datetime
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from llama_models.schema_utils import webmethod
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import ( from llama_stack.apis.post_training import (
@ -27,6 +25,7 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import ( from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
LoraFinetuningSingleDevice, LoraFinetuningSingleDevice,
) )
from llama_stack.schema_utils import webmethod
class TorchtunePostTrainingImpl: class TorchtunePostTrainingImpl:

View file

@ -14,7 +14,6 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import torch import torch
from llama_models.sku_list import resolve_model
from torch import nn from torch import nn
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler from torch.utils.data import DataLoader, DistributedSampler
@ -46,6 +45,7 @@ from llama_stack.apis.post_training import (
) )
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.inline.post_training.common.validator import ( from llama_stack.providers.inline.post_training.common.validator import (
validate_input_dataset_schema, validate_input_dataset_schema,
) )

View file

@ -8,9 +8,6 @@ import re
from string import Template from string import Template
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.datatypes import Role
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionResponseEventType, ChatCompletionResponseEventType,
@ -26,6 +23,7 @@ from llama_stack.apis.safety import (
) )
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.distribution.datatypes import Api from llama_stack.distribution.datatypes import Api
from llama_stack.models.llama.datatypes import CoreModelId, Role
from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str, interleaved_content_as_str,

View file

@ -6,13 +6,13 @@
from typing import Any, Dict from typing import Any, Dict
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.providers.utils.kvstore.config import ( from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig, KVStoreConfig,
SqliteKVStoreConfig, SqliteKVStoreConfig,
) )
from llama_stack.schema_utils import json_schema_type
@json_schema_type @json_schema_type

View file

@ -8,7 +8,6 @@ import json
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from botocore.client import BaseClient from botocore.client import BaseClient
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
@ -28,6 +27,7 @@ from llama_stack.apis.inference import (
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (

View file

@ -7,9 +7,7 @@
from typing import AsyncGenerator, List, Optional, Union from typing import AsyncGenerator, List, Optional, Union
from cerebras.cloud.sdk import AsyncCerebras from cerebras.cloud.sdk import AsyncCerebras
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import TopKSamplingStrategy
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.common.content_types import InterleavedContent
@ -28,6 +26,7 @@ from llama_stack.apis.inference import (
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.models.llama.datatypes import CoreModelId, TopKSamplingStrategy
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper, ModelRegistryHelper,
build_model_alias, build_model_alias,

View file

@ -7,9 +7,10 @@
import os import os
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, SecretStr from pydantic import BaseModel, Field, SecretStr
from llama_stack.schema_utils import json_schema_type
DEFAULT_BASE_URL = "https://api.cerebras.ai" DEFAULT_BASE_URL = "https://api.cerebras.ai"

View file

@ -5,9 +5,10 @@
# the root directory of this source tree. # the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@json_schema_type @json_schema_type
class DatabricksImplConfig(BaseModel): class DatabricksImplConfig(BaseModel):

View file

@ -6,7 +6,6 @@
from typing import AsyncGenerator, List, Optional from typing import AsyncGenerator, List, Optional
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI from openai import OpenAI
@ -25,6 +24,7 @@ from llama_stack.apis.inference import (
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper, ModelRegistryHelper,
build_model_alias, build_model_alias,

View file

@ -6,9 +6,10 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, SecretStr from pydantic import BaseModel, Field, SecretStr
from llama_stack.schema_utils import json_schema_type
@json_schema_type @json_schema_type
class FireworksImplConfig(BaseModel): class FireworksImplConfig(BaseModel):

View file

@ -7,7 +7,6 @@
from typing import AsyncGenerator, List, Optional, Union from typing import AsyncGenerator, List, Optional, Union
from fireworks.client import Fireworks from fireworks.client import Fireworks
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
@ -30,6 +29,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper, ModelRegistryHelper,
build_model_alias, build_model_alias,

View file

@ -6,9 +6,10 @@
from typing import Optional from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@json_schema_type @json_schema_type
class GroqConfig(BaseModel): class GroqConfig(BaseModel):

View file

@ -9,9 +9,6 @@ from typing import AsyncIterator, List, Optional, Union
import groq import groq
from groq import Groq from groq import Groq
from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
from llama_models.sku_list import CoreModelId
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionRequest, ChatCompletionRequest,
@ -29,6 +26,8 @@ from llama_stack.apis.inference import (
ToolConfig, ToolConfig,
) )
from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
from llama_stack.models.llama.sku_list import CoreModelId
from llama_stack.providers.remote.inference.groq.config import GroqConfig from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper, ModelRegistryHelper,

View file

@ -24,7 +24,6 @@ from groq.types.chat.chat_completion_user_message_param import (
) )
from groq.types.chat.completion_create_params import CompletionCreateParams from groq.types.chat.completion_create_params import CompletionCreateParams
from groq.types.shared.function_definition import FunctionDefinition from groq.types.shared.function_definition import FunctionDefinition
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
TextDelta, TextDelta,
@ -44,6 +43,7 @@ from llama_stack.apis.inference import (
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.models.llama.datatypes import ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.openai_compat import (
UnparseableToolCall, UnparseableToolCall,
convert_tool_call, convert_tool_call,

View file

@ -7,9 +7,10 @@
import os import os
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, SecretStr from pydantic import BaseModel, Field, SecretStr
from llama_stack.schema_utils import json_schema_type
@json_schema_type @json_schema_type
class NVIDIAConfig(BaseModel): class NVIDIAConfig(BaseModel):

View file

@ -7,9 +7,6 @@
import warnings import warnings
from typing import AsyncIterator, List, Optional, Union from typing import AsyncIterator, List, Optional, Union
from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
from llama_models.sku_list import CoreModelId
from openai import APIConnectionError, AsyncOpenAI from openai import APIConnectionError, AsyncOpenAI
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
@ -28,6 +25,7 @@ from llama_stack.apis.inference import (
ToolChoice, ToolChoice,
ToolConfig, ToolConfig,
) )
from llama_stack.models.llama.datatypes import CoreModelId, SamplingParams, ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper, ModelRegistryHelper,
build_model_alias, build_model_alias,

View file

@ -8,17 +8,6 @@ import json
import warnings import warnings
from typing import Any, AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union from typing import Any, AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union
from llama_models.datatypes import (
GreedySamplingStrategy,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_models.llama3.api.datatypes import (
BuiltinTool,
StopReason,
ToolCall,
ToolDefinition,
)
from openai import AsyncStream from openai import AsyncStream
from openai.types.chat import ( from openai.types.chat import (
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
@ -87,6 +76,15 @@ from llama_stack.apis.inference import (
ToolResponseMessage, ToolResponseMessage,
UserMessage, UserMessage,
) )
from llama_stack.models.llama.datatypes import (
BuiltinTool,
GreedySamplingStrategy,
StopReason,
ToolCall,
ToolDefinition,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url, convert_image_content_to_url,
) )

View file

@ -8,7 +8,6 @@ import logging
from typing import AsyncGenerator, List, Optional, Union from typing import AsyncGenerator, List, Optional, Union
import httpx import httpx
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from ollama import AsyncClient from ollama import AsyncClient
@ -34,6 +33,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.models import Model, ModelType from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper, ModelRegistryHelper,

View file

@ -6,9 +6,10 @@
from typing import Optional from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@json_schema_type @json_schema_type
class RunpodImplConfig(BaseModel): class RunpodImplConfig(BaseModel):

View file

@ -6,11 +6,11 @@
from typing import AsyncGenerator from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.models.llama.datatypes import Message
# from llama_stack.providers.datatypes import ModelsProtocolPrivate # from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

View file

@ -6,9 +6,10 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@json_schema_type @json_schema_type
class SambaNovaImplConfig(BaseModel): class SambaNovaImplConfig(BaseModel):

View file

@ -7,12 +7,6 @@
import json import json
from typing import AsyncGenerator from typing import AsyncGenerator
from llama_models.datatypes import (
CoreModelId,
GreedySamplingStrategy,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI from openai import OpenAI
@ -23,6 +17,12 @@ from llama_stack.apis.common.content_types import (
TextContentItem, TextContentItem,
) )
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.models.llama.datatypes import (
CoreModelId,
GreedySamplingStrategy,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper, ModelRegistryHelper,
build_model_alias, build_model_alias,

View file

@ -6,9 +6,10 @@
from typing import Optional from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, SecretStr from pydantic import BaseModel, Field, SecretStr
from llama_stack.schema_utils import json_schema_type
@json_schema_type @json_schema_type
class TGIImplConfig(BaseModel): class TGIImplConfig(BaseModel):

View file

@ -11,7 +11,6 @@ from typing import AsyncGenerator, List, Optional
from huggingface_hub import AsyncInferenceClient, HfApi from huggingface_hub import AsyncInferenceClient, HfApi
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import all_registered_models
from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
@ -31,6 +30,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.models import Model from llama_stack.apis.models import Model
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper, ModelRegistryHelper,

View file

@ -6,9 +6,10 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, SecretStr from pydantic import BaseModel, Field, SecretStr
from llama_stack.schema_utils import json_schema_type
@json_schema_type @json_schema_type
class TogetherImplConfig(BaseModel): class TogetherImplConfig(BaseModel):

View file

@ -6,7 +6,6 @@
from typing import AsyncGenerator, List, Optional, Union from typing import AsyncGenerator, List, Optional, Union
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from together import Together from together import Together
@ -29,6 +28,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper, ModelRegistryHelper,
build_model_alias, build_model_alias,

View file

@ -6,9 +6,10 @@
from typing import Optional from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@json_schema_type @json_schema_type
class VLLMInferenceAdapterConfig(BaseModel): class VLLMInferenceAdapterConfig(BaseModel):

View file

@ -7,10 +7,9 @@ import json
import logging import logging
from typing import AsyncGenerator, List, Optional, Union from typing import AsyncGenerator, List, Optional, Union
from llama_models.llama3.api import StopReason, ToolCall from llama_models.datatypes import StopReason, ToolCall
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import all_registered_models
from openai import OpenAI from openai import OpenAI
from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
@ -37,6 +36,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.models import Model, ModelType from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper, ModelRegistryHelper,

View file

@ -5,9 +5,8 @@
# the root directory of this source tree. # the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type @json_schema_type

View file

@ -7,7 +7,6 @@
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import requests import requests
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
@ -18,6 +17,7 @@ from llama_stack.apis.tools import (
ToolRuntime, ToolRuntime,
) )
from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.datatypes import ToolsProtocolPrivate from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import BraveSearchToolConfig from .config import BraveSearchToolConfig

View file

@ -4,9 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@json_schema_type @json_schema_type
class PGVectorVectorIOConfig(BaseModel): class PGVectorVectorIOConfig(BaseModel):

View file

@ -6,9 +6,10 @@
from typing import Optional from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
@json_schema_type @json_schema_type
class QdrantVectorIOConfig(BaseModel): class QdrantVectorIOConfig(BaseModel):

View file

@ -7,8 +7,6 @@
import os import os
import pytest import pytest
from llama_models.datatypes import SamplingParams, TopPSamplingStrategy
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.agents import ( from llama_stack.apis.agents import (
AgentConfig, AgentConfig,
@ -25,6 +23,7 @@ from llama_stack.apis.agents import (
) )
from llama_stack.apis.inference import CompletionMessage, UserMessage from llama_stack.apis.inference import CompletionMessage, UserMessage
from llama_stack.apis.safety import ViolationLevel from llama_stack.apis.safety import ViolationLevel
from llama_stack.models.llama.datatypes import BuiltinTool, SamplingParams, TopPSamplingStrategy
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
# How to run this test: # How to run this test:

Some files were not shown because too many files have changed in this diff Show more