chore: move all Llama Stack types from llama-models to llama-stack (#1098)

llama-models should have extremely minimal cruft. Its sole purpose
should be didactic -- show the simplest implementation of the llama
models and document the prompt formats, etc.

This PR is the complement to
https://github.com/meta-llama/llama-models/pull/279

## Test Plan

Ensure all `llama` CLI `model` sub-commands work:

```bash
llama model list
llama model download --model-id ...
llama model prompt-format -m ...
```

Ran tests:
```bash
cd tests/client-sdk
LLAMA_STACK_CONFIG=fireworks pytest -s -v inference/
LLAMA_STACK_CONFIG=fireworks pytest -s -v vector_io/
LLAMA_STACK_CONFIG=fireworks pytest -s -v agents/
```

Create a fresh venv `uv venv && source .venv/bin/activate` and run
`llama stack build --template fireworks --image-type venv` followed by
`llama stack run together --image-type venv` <-- the server runs

Also checked that the OpenAPI generator can run and there is no change
in the generated files as a result.

```bash
cd docs/openapi_generator
sh run_openapi_generator.sh
```
This commit is contained in:
Ashwin Bharambe 2025-02-14 09:10:59 -08:00 committed by GitHub
parent c0ee512980
commit 314ee09ae3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
138 changed files with 8491 additions and 465 deletions

View file

@ -19,7 +19,6 @@ from typing import (
runtime_checkable,
)
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
@ -38,6 +37,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
class Attachment(BaseModel):

View file

@ -1,206 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.llama3.api.datatypes import ToolPromptFormat
from llama_models.llama3.api.tool_utils import ToolUtils
from termcolor import cprint
from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
from llama_stack.apis.common.content_types import ToolCallParseStatus
from llama_stack.apis.inference import ToolResponseMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
class LogEvent:
def __init__(
self,
role: Optional[str] = None,
content: str = "",
end: str = "\n",
color="white",
):
self.role = role
self.content = content
self.color = color
self.end = "\n" if end is None else end
def __str__(self):
if self.role is not None:
return f"{self.role}> {self.content}"
else:
return f"{self.content}"
def print(self, flush=True):
cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush)
EventType = AgentTurnResponseEventType
class EventLogger:
async def log(
self,
event_generator,
stream=True,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
):
previous_event_type = None
previous_step_type = None
async for chunk in event_generator:
if not hasattr(chunk, "event"):
# Need to check for custom tool first
# since it does not produce event but instead
# a Message
if isinstance(chunk, ToolResponseMessage):
yield (
chunk,
LogEvent(role="CustomTool", content=chunk.content, color="grey"),
)
continue
event = chunk.event
event_type = event.payload.event_type
if event_type in {
EventType.turn_start.value,
EventType.turn_complete.value,
}:
# Currently not logging any turn realted info
yield event, None
continue
step_type = event.payload.step_type
# handle safety
if step_type == StepType.shield_call and event_type == EventType.step_complete.value:
violation = event.payload.step_details.violation
if not violation:
yield (
event,
LogEvent(role=step_type, content="No Violation", color="magenta"),
)
else:
yield (
event,
LogEvent(
role=step_type,
content=f"{violation.metadata} {violation.user_message}",
color="red",
),
)
# handle inference
if step_type == StepType.inference:
if stream:
if event_type == EventType.step_start.value:
# TODO: Currently this event is never received
yield (
event,
LogEvent(role=step_type, content="", end="", color="yellow"),
)
elif event_type == EventType.step_progress.value:
# HACK: if previous was not step/event was not inference's step_progress
# this is the first time we are getting model inference response
# aka equivalent to step_start for inference. Hence,
# start with "Model>".
if (
previous_event_type != EventType.step_progress.value
and previous_step_type != StepType.inference
):
yield (
event,
LogEvent(role=step_type, content="", end="", color="yellow"),
)
delta = event.payload.delta
if delta.type == "tool_call":
if delta.parse_status == ToolCallParseStatus.succeeded:
yield (
event,
LogEvent(
role=None,
content=delta.tool_call,
end="",
color="cyan",
),
)
else:
yield (
event,
LogEvent(
role=None,
content=delta.text,
end="",
color="yellow",
),
)
else:
# step_complete
yield event, LogEvent(role=None, content="")
else:
# Not streaming
if event_type == EventType.step_complete.value:
response = event.payload.step_details.model_response
if response.tool_calls:
content = ToolUtils.encode_tool_call(response.tool_calls[0], tool_prompt_format)
else:
content = response.content
yield (
event,
LogEvent(
role=step_type,
content=content,
color="yellow",
),
)
# handle tool_execution
if (
step_type == StepType.tool_execution
and
# Only print tool calls and responses at the step_complete event
event_type == EventType.step_complete.value
):
details = event.payload.step_details
for t in details.tool_calls:
yield (
event,
LogEvent(
role=step_type,
content=f"Tool:{t.tool_name} Args:{t.arguments}",
color="green",
),
)
for r in details.tool_responses:
yield (
event,
LogEvent(
role=step_type,
content=f"Tool:{r.tool_name} Response:{r.content}",
color="green",
),
)
if step_type == StepType.memory_retrieval and event_type == EventType.step_complete.value:
details = event.payload.step_details
inserted_context = interleaved_content_as_str(details.inserted_context)
content = f"fetched {len(inserted_context)} bytes from {details.vector_db_ids}"
yield (
event,
LogEvent(
role=step_type,
content=content,
color="cyan",
),
)
previous_event_type = event_type
previous_step_type = step_type

View file

@ -6,7 +6,6 @@
from typing import List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_stack.apis.inference import (
@ -21,6 +20,7 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type

View file

@ -5,10 +5,10 @@
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, webmethod
class CommonBenchmarkFields(BaseModel):

View file

@ -7,10 +7,11 @@
from enum import Enum
from typing import Annotated, List, Literal, Optional, Union
from llama_models.llama3.api.datatypes import ToolCall
from llama_models.schema_utils import json_schema_type, register_schema
from pydantic import BaseModel, Field, model_validator
from llama_stack.models.llama.datatypes import ToolCall
from llama_stack.schema_utils import json_schema_type, register_schema
@json_schema_type
class URL(BaseModel):

View file

@ -7,10 +7,10 @@
from enum import Enum
from typing import Any, Dict, Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
from llama_stack.apis.common.content_types import URL
from llama_stack.schema_utils import json_schema_type
@json_schema_type

View file

@ -5,9 +5,10 @@
# the root directory of this source tree.
from enum import Enum
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class Job(BaseModel):

View file

@ -7,9 +7,10 @@
from datetime import datetime
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class PostTrainingMetric(BaseModel):

View file

@ -6,10 +6,11 @@
from typing import Literal, Union
from llama_models.schema_utils import json_schema_type, register_schema
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.schema_utils import json_schema_type, register_schema
@json_schema_type
class StringType(BaseModel):

View file

@ -6,10 +6,10 @@
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_stack.apis.datasets import Dataset
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type

View file

@ -6,12 +6,12 @@
from typing import Any, Dict, List, Literal, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, webmethod
class CommonDatasetFields(BaseModel):

View file

@ -6,7 +6,7 @@
from enum import Enum
from llama_models.schema_utils import json_schema_type
from llama_stack.schema_utils import json_schema_type
@json_schema_type

View file

@ -6,7 +6,6 @@
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
@ -15,6 +14,7 @@ from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type

View file

@ -17,7 +17,13 @@ from typing import (
runtime_checkable,
)
from llama_models.llama3.api.datatypes import (
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.models.llama.datatypes import (
BuiltinTool,
SamplingParams,
StopReason,
@ -25,14 +31,8 @@ from llama_models.llama3.api.datatypes import (
ToolDefinition,
ToolPromptFormat,
)
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
class LogProbConfig(BaseModel):

View file

@ -6,9 +6,10 @@
from typing import List, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class ProviderInfo(BaseModel):

View file

@ -7,11 +7,11 @@
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
class CommonModelFields(BaseModel):

View file

@ -8,13 +8,13 @@ from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.common.training_types import Checkpoint
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type

View file

@ -7,12 +7,12 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.inference import Message
from llama_stack.apis.shields import Shield
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type

View file

@ -6,10 +6,10 @@
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.schema_utils import json_schema_type, webmethod
# mapping of metric to value
ScoringResultRow = Dict[str, Any]

View file

@ -16,12 +16,12 @@ from typing import (
runtime_checkable,
)
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
# Perhaps more structure can be imposed on these functions. Maybe they could be associated

View file

@ -6,11 +6,11 @@
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
class CommonShieldFields(BaseModel):

View file

@ -7,10 +7,10 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_stack.apis.inference import Message
from llama_stack.schema_utils import json_schema_type, webmethod
class FilteringFunction(Enum):

View file

@ -17,11 +17,12 @@ from typing import (
runtime_checkable,
)
from llama_models.llama3.api.datatypes import Primitive
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.models.llama.datatypes import Primitive
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
# Add this constant near the top of the file, after the imports
DEFAULT_TTL_DAYS = 7

View file

@ -7,12 +7,12 @@
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated, Protocol, runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type

View file

@ -7,13 +7,13 @@
from enum import Enum
from typing import Any, Dict, List, Literal, Optional
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
from .rag_tool import RAGToolRuntime

View file

@ -6,11 +6,11 @@
from typing import List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type

View file

@ -10,12 +10,12 @@
# the root directory of this source tree.
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
class Chunk(BaseModel):