mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-28 16:41:59 +00:00
Merge branch 'main' into add-watsonx-inference-adapter
This commit is contained in:
commit
28e6c8478b
308 changed files with 33749 additions and 5102 deletions
|
|
@ -189,13 +189,11 @@ class AgentToolGroupWithArgs(BaseModel):
|
|||
args: Dict[str, Any]
|
||||
|
||||
|
||||
AgentToolGroup = register_schema(
|
||||
Union[
|
||||
str,
|
||||
AgentToolGroupWithArgs,
|
||||
],
|
||||
name="AgentTool",
|
||||
)
|
||||
AgentToolGroup = Union[
|
||||
str,
|
||||
AgentToolGroupWithArgs,
|
||||
]
|
||||
register_schema(AgentToolGroup, name="AgentTool")
|
||||
|
||||
|
||||
class AgentConfigCommon(BaseModel):
|
||||
|
|
@ -234,6 +232,23 @@ class AgentConfig(AgentConfigCommon):
|
|||
response_format: Optional[ResponseFormat] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Agent(BaseModel):
|
||||
agent_id: str
|
||||
agent_config: AgentConfig
|
||||
created_at: datetime
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListAgentsResponse(BaseModel):
|
||||
data: List[Agent]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListAgentSessionsResponse(BaseModel):
|
||||
data: List[Session]
|
||||
|
||||
|
||||
class AgentConfigOverridablePerTurn(AgentConfigCommon):
|
||||
instructions: Optional[str] = None
|
||||
|
||||
|
|
@ -295,20 +310,18 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
|
|||
turn: Turn
|
||||
|
||||
|
||||
AgentTurnResponseEventPayload = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
AgentTurnResponseStepStartPayload,
|
||||
AgentTurnResponseStepProgressPayload,
|
||||
AgentTurnResponseStepCompletePayload,
|
||||
AgentTurnResponseTurnStartPayload,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
AgentTurnResponseTurnAwaitingInputPayload,
|
||||
],
|
||||
Field(discriminator="event_type"),
|
||||
AgentTurnResponseEventPayload = Annotated[
|
||||
Union[
|
||||
AgentTurnResponseStepStartPayload,
|
||||
AgentTurnResponseStepProgressPayload,
|
||||
AgentTurnResponseStepCompletePayload,
|
||||
AgentTurnResponseTurnStartPayload,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
AgentTurnResponseTurnAwaitingInputPayload,
|
||||
],
|
||||
name="AgentTurnResponseEventPayload",
|
||||
)
|
||||
Field(discriminator="event_type"),
|
||||
]
|
||||
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -353,7 +366,7 @@ class AgentTurnResumeRequest(BaseModel):
|
|||
agent_id: str
|
||||
session_id: str
|
||||
turn_id: str
|
||||
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]]
|
||||
tool_responses: List[ToolResponse]
|
||||
stream: Optional[bool] = False
|
||||
|
||||
|
||||
|
|
@ -432,7 +445,7 @@ class Agents(Protocol):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]],
|
||||
tool_responses: List[ToolResponse],
|
||||
stream: Optional[bool] = False,
|
||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
|
||||
"""Resume an agent turn with executed tool call responses.
|
||||
|
|
@ -443,7 +456,6 @@ class Agents(Protocol):
|
|||
:param session_id: The ID of the session to resume.
|
||||
:param turn_id: The ID of the turn to resume.
|
||||
:param tool_responses: The tool call responses to resume the turn with.
|
||||
NOTE: ToolResponseMessage will be deprecated. Use ToolResponse.
|
||||
:param stream: Whether to stream the response.
|
||||
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
|
||||
"""
|
||||
|
|
@ -541,3 +553,32 @@ class Agents(Protocol):
|
|||
:param agent_id: The ID of the agent to delete.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents", method="GET")
|
||||
async def list_agents(self) -> ListAgentsResponse:
|
||||
"""List all agents.
|
||||
|
||||
:returns: A ListAgentsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}", method="GET")
|
||||
async def get_agent(self, agent_id: str) -> Agent:
|
||||
"""Describe an agent by its ID.
|
||||
|
||||
:param agent_id: ID of the agent.
|
||||
:returns: An Agent of the agent.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/sessions", method="GET")
|
||||
async def list_agent_sessions(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> ListAgentSessionsResponse:
|
||||
"""List all session(s) of a given agent.
|
||||
|
||||
:param agent_id: The ID of the agent to list sessions for.
|
||||
:returns: A ListAgentSessionsResponse.
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ class Benchmarks(Protocol):
|
|||
async def get_benchmark(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
) -> Optional[Benchmark]: ...
|
||||
) -> Benchmark: ...
|
||||
|
||||
@webmethod(route="/eval/benchmarks", method="POST")
|
||||
async def register_benchmark(
|
||||
|
|
|
|||
|
|
@ -63,19 +63,15 @@ class TextContentItem(BaseModel):
|
|||
|
||||
|
||||
# other modalities can be added here
|
||||
InterleavedContentItem = register_schema(
|
||||
Annotated[
|
||||
Union[ImageContentItem, TextContentItem],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="InterleavedContentItem",
|
||||
)
|
||||
InterleavedContentItem = Annotated[
|
||||
Union[ImageContentItem, TextContentItem],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(InterleavedContentItem, name="InterleavedContentItem")
|
||||
|
||||
# accept a single "str" as a special case since it is common
|
||||
InterleavedContent = register_schema(
|
||||
Union[str, InterleavedContentItem, List[InterleavedContentItem]],
|
||||
name="InterleavedContent",
|
||||
)
|
||||
InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]]
|
||||
register_schema(InterleavedContent, name="InterleavedContent")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -109,10 +105,8 @@ class ToolCallDelta(BaseModel):
|
|||
|
||||
|
||||
# streaming completions send a stream of ContentDeltas
|
||||
ContentDelta = register_schema(
|
||||
Annotated[
|
||||
Union[TextDelta, ImageDelta, ToolCallDelta],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="ContentDelta",
|
||||
)
|
||||
ContentDelta = Annotated[
|
||||
Union[TextDelta, ImageDelta, ToolCallDelta],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(ContentDelta, name="ContentDelta")
|
||||
|
|
|
|||
|
|
@ -72,24 +72,22 @@ class DialogType(BaseModel):
|
|||
type: Literal["dialog"] = "dialog"
|
||||
|
||||
|
||||
ParamType = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
StringType,
|
||||
NumberType,
|
||||
BooleanType,
|
||||
ArrayType,
|
||||
ObjectType,
|
||||
JsonType,
|
||||
UnionType,
|
||||
ChatCompletionInputType,
|
||||
CompletionInputType,
|
||||
AgentTurnInputType,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
ParamType = Annotated[
|
||||
Union[
|
||||
StringType,
|
||||
NumberType,
|
||||
BooleanType,
|
||||
ArrayType,
|
||||
ObjectType,
|
||||
JsonType,
|
||||
UnionType,
|
||||
ChatCompletionInputType,
|
||||
CompletionInputType,
|
||||
AgentTurnInputType,
|
||||
],
|
||||
name="ParamType",
|
||||
)
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(ParamType, name="ParamType")
|
||||
|
||||
"""
|
||||
# TODO: recursive definition of ParamType in these containers
|
||||
|
|
|
|||
|
|
@ -13,19 +13,16 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class PaginatedRowsResult(BaseModel):
|
||||
class IterrowsResponse(BaseModel):
|
||||
"""
|
||||
A paginated list of rows from a dataset.
|
||||
|
||||
:param rows: The rows in the current page.
|
||||
:param total_count: The total number of rows in the dataset.
|
||||
:param next_page_token: The token to get the next page of rows.
|
||||
:param data: The rows in the current page.
|
||||
:param next_start_index: Index into dataset for the first row in the next page. None if there are no more rows.
|
||||
"""
|
||||
|
||||
# the rows obey the DatasetSchema for the given dataset
|
||||
rows: List[Dict[str, Any]]
|
||||
total_count: int
|
||||
next_page_token: Optional[str] = None
|
||||
data: List[Dict[str, Any]]
|
||||
next_start_index: Optional[int] = None
|
||||
|
||||
|
||||
class DatasetStore(Protocol):
|
||||
|
|
@ -37,22 +34,21 @@ class DatasetIO(Protocol):
|
|||
# keeping for aligning with inference/safety, but this is not used
|
||||
dataset_store: DatasetStore
|
||||
|
||||
@webmethod(route="/datasetio/rows", method="GET")
|
||||
async def get_rows_paginated(
|
||||
# TODO(xiyan): there's a flakiness here where setting route to "/datasets/" here will not result in proper routing
|
||||
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET")
|
||||
async def iterrows(
|
||||
self,
|
||||
dataset_id: str,
|
||||
rows_in_page: int,
|
||||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult:
|
||||
"""Get a paginated list of rows from a dataset.
|
||||
start_index: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> IterrowsResponse:
|
||||
"""Get a paginated list of rows from a dataset. Uses cursor-based pagination.
|
||||
|
||||
:param dataset_id: The ID of the dataset to get the rows from.
|
||||
:param rows_in_page: The number of rows to get per page.
|
||||
:param page_token: The token to get the next page of rows.
|
||||
:param filter_condition: (Optional) A condition to filter the rows by.
|
||||
:param start_index: Index into dataset for the first row to get. Get all rows if None.
|
||||
:param limit: The number of rows to get.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasetio/rows", method="POST")
|
||||
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
|
||||
|
|
|
|||
|
|
@ -4,19 +4,100 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Protocol, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
class DatasetPurpose(str, Enum):
|
||||
"""
|
||||
Purpose of the dataset. Each purpose has a required input data schema.
|
||||
|
||||
:cvar post-training/messages: The dataset contains messages used for post-training.
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, world!"},
|
||||
{"role": "assistant", "content": "Hello, world!"},
|
||||
]
|
||||
}
|
||||
:cvar eval/question-answer: The dataset contains a question column and an answer column.
|
||||
{
|
||||
"question": "What is the capital of France?",
|
||||
"answer": "Paris"
|
||||
}
|
||||
:cvar eval/messages-answer: The dataset contains a messages column with list of messages and an answer column.
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, my name is John Doe."},
|
||||
{"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
|
||||
{"role": "user", "content": "What's my name?"},
|
||||
],
|
||||
"answer": "John Doe"
|
||||
}
|
||||
"""
|
||||
|
||||
post_training_messages = "post-training/messages"
|
||||
eval_question_answer = "eval/question-answer"
|
||||
eval_messages_answer = "eval/messages-answer"
|
||||
|
||||
# TODO: add more schemas here
|
||||
|
||||
|
||||
class DatasetType(Enum):
|
||||
"""
|
||||
Type of the dataset source.
|
||||
:cvar uri: The dataset can be obtained from a URI.
|
||||
:cvar rows: The dataset is stored in rows.
|
||||
"""
|
||||
|
||||
uri = "uri"
|
||||
rows = "rows"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class URIDataSource(BaseModel):
|
||||
"""A dataset that can be obtained from a URI.
|
||||
:param uri: The dataset can be obtained from a URI. E.g.
|
||||
- "https://mywebsite.com/mydata.jsonl"
|
||||
- "lsfs://mydata.jsonl"
|
||||
- "data:csv;base64,{base64_content}"
|
||||
"""
|
||||
|
||||
type: Literal["uri"] = "uri"
|
||||
uri: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RowsDataSource(BaseModel):
|
||||
"""A dataset stored in rows.
|
||||
:param rows: The dataset is stored in rows. E.g.
|
||||
- [
|
||||
{"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]}
|
||||
]
|
||||
"""
|
||||
|
||||
type: Literal["rows"] = "rows"
|
||||
rows: List[Dict[str, Any]]
|
||||
|
||||
|
||||
DataSource = Annotated[
|
||||
Union[URIDataSource, RowsDataSource],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(DataSource, name="DataSource")
|
||||
|
||||
|
||||
class CommonDatasetFields(BaseModel):
|
||||
dataset_schema: Dict[str, ParamType]
|
||||
url: URL
|
||||
"""
|
||||
Common fields for a dataset.
|
||||
"""
|
||||
|
||||
purpose: DatasetPurpose
|
||||
source: DataSource
|
||||
metadata: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Any additional metadata for this dataset",
|
||||
|
|
@ -38,8 +119,6 @@ class Dataset(CommonDatasetFields, Resource):
|
|||
|
||||
class DatasetInput(CommonDatasetFields, BaseModel):
|
||||
dataset_id: str
|
||||
provider_id: Optional[str] = None
|
||||
provider_dataset_id: Optional[str] = None
|
||||
|
||||
|
||||
class ListDatasetsResponse(BaseModel):
|
||||
|
|
@ -50,19 +129,75 @@ class Datasets(Protocol):
|
|||
@webmethod(route="/datasets", method="POST")
|
||||
async def register_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
dataset_schema: Dict[str, ParamType],
|
||||
url: URL,
|
||||
provider_dataset_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
purpose: DatasetPurpose,
|
||||
source: DataSource,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None: ...
|
||||
dataset_id: Optional[str] = None,
|
||||
) -> Dataset:
|
||||
"""
|
||||
Register a new dataset.
|
||||
|
||||
:param purpose: The purpose of the dataset. One of
|
||||
- "post-training/messages": The dataset contains a messages column with list of messages for post-training.
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, world!"},
|
||||
{"role": "assistant", "content": "Hello, world!"},
|
||||
]
|
||||
}
|
||||
- "eval/question-answer": The dataset contains a question column and an answer column for evaluation.
|
||||
{
|
||||
"question": "What is the capital of France?",
|
||||
"answer": "Paris"
|
||||
}
|
||||
- "eval/messages-answer": The dataset contains a messages column with list of messages and an answer column for evaluation.
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, my name is John Doe."},
|
||||
{"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
|
||||
{"role": "user", "content": "What's my name?"},
|
||||
],
|
||||
"answer": "John Doe"
|
||||
}
|
||||
:param source: The data source of the dataset. Ensure that the data source schema is compatible with the purpose of the dataset. Examples:
|
||||
- {
|
||||
"type": "uri",
|
||||
"uri": "https://mywebsite.com/mydata.jsonl"
|
||||
}
|
||||
- {
|
||||
"type": "uri",
|
||||
"uri": "lsfs://mydata.jsonl"
|
||||
}
|
||||
- {
|
||||
"type": "uri",
|
||||
"uri": "data:csv;base64,{base64_content}"
|
||||
}
|
||||
- {
|
||||
"type": "uri",
|
||||
"uri": "huggingface://llamastack/simpleqa?split=train"
|
||||
}
|
||||
- {
|
||||
"type": "rows",
|
||||
"rows": [
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, world!"},
|
||||
{"role": "assistant", "content": "Hello, world!"},
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
:param metadata: The metadata for the dataset.
|
||||
- E.g. {"description": "My dataset"}
|
||||
:param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="GET")
|
||||
async def get_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
) -> Optional[Dataset]: ...
|
||||
) -> Dataset: ...
|
||||
|
||||
@webmethod(route="/datasets", method="GET")
|
||||
async def list_datasets(self) -> ListDatasetsResponse: ...
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from llama_stack.schema_utils import json_schema_type
|
|||
|
||||
@json_schema_type
|
||||
class Api(Enum):
|
||||
providers = "providers"
|
||||
inference = "inference"
|
||||
safety = "safety"
|
||||
agents = "agents"
|
||||
|
|
|
|||
|
|
@ -43,10 +43,8 @@ class AgentCandidate(BaseModel):
|
|||
config: AgentConfig
|
||||
|
||||
|
||||
EvalCandidate = register_schema(
|
||||
Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")],
|
||||
name="EvalCandidate",
|
||||
)
|
||||
EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")]
|
||||
register_schema(EvalCandidate, name="EvalCandidate")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -117,7 +115,7 @@ class Eval(Protocol):
|
|||
"""
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus:
|
||||
"""Get the status of a job.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ class Files(Protocol):
|
|||
async def get_upload_session_info(
|
||||
self,
|
||||
upload_id: str,
|
||||
) -> Optional[FileUploadResponse]:
|
||||
) -> FileUploadResponse:
|
||||
"""
|
||||
Returns information about an existsing upload session
|
||||
|
||||
|
|
|
|||
|
|
@ -117,13 +117,11 @@ class ToolResponseMessage(BaseModel):
|
|||
|
||||
:param role: Must be "tool" to identify this as a tool response
|
||||
:param call_id: Unique identifier for the tool call this response is for
|
||||
:param tool_name: Name of the tool that was called
|
||||
:param content: The response content from the tool
|
||||
"""
|
||||
|
||||
role: Literal["tool"] = "tool"
|
||||
call_id: str
|
||||
tool_name: Union[BuiltinTool, str]
|
||||
content: InterleavedContent
|
||||
|
||||
|
||||
|
|
@ -146,18 +144,16 @@ class CompletionMessage(BaseModel):
|
|||
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
|
||||
|
||||
|
||||
Message = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
UserMessage,
|
||||
SystemMessage,
|
||||
ToolResponseMessage,
|
||||
CompletionMessage,
|
||||
],
|
||||
Field(discriminator="role"),
|
||||
Message = Annotated[
|
||||
Union[
|
||||
UserMessage,
|
||||
SystemMessage,
|
||||
ToolResponseMessage,
|
||||
CompletionMessage,
|
||||
],
|
||||
name="Message",
|
||||
)
|
||||
Field(discriminator="role"),
|
||||
]
|
||||
register_schema(Message, name="Message")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -265,13 +261,11 @@ class GrammarResponseFormat(BaseModel):
|
|||
bnf: Dict[str, Any]
|
||||
|
||||
|
||||
ResponseFormat = register_schema(
|
||||
Annotated[
|
||||
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="ResponseFormat",
|
||||
)
|
||||
ResponseFormat = Annotated[
|
||||
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(ResponseFormat, name="ResponseFormat")
|
||||
|
||||
|
||||
# This is an internally used class
|
||||
|
|
@ -285,7 +279,7 @@ class CompletionRequest(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class CompletionResponse(BaseModel):
|
||||
class CompletionResponse(MetricResponseMixin):
|
||||
"""Response from a completion request.
|
||||
|
||||
:param content: The generated completion text
|
||||
|
|
@ -299,7 +293,7 @@ class CompletionResponse(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class CompletionResponseStreamChunk(BaseModel):
|
||||
class CompletionResponseStreamChunk(MetricResponseMixin):
|
||||
"""A chunk of a streamed completion response.
|
||||
|
||||
:param delta: New content generated since last chunk. This can be one or more tokens.
|
||||
|
|
@ -368,7 +362,7 @@ class ChatCompletionRequest(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
|
||||
class ChatCompletionResponseStreamChunk(MetricResponseMixin):
|
||||
"""A chunk of a streamed chat completion response.
|
||||
|
||||
:param event: The event containing the new content
|
||||
|
|
@ -378,7 +372,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class ChatCompletionResponse(MetricResponseMixin, BaseModel):
|
||||
class ChatCompletionResponse(MetricResponseMixin):
|
||||
"""Response from a chat completion request.
|
||||
|
||||
:param completion_message: The complete response message
|
||||
|
|
|
|||
|
|
@ -11,13 +11,6 @@ from pydantic import BaseModel
|
|||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderInfo(BaseModel):
|
||||
api: str
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RouteInfo(BaseModel):
|
||||
route: str
|
||||
|
|
@ -36,19 +29,12 @@ class VersionInfo(BaseModel):
|
|||
version: str
|
||||
|
||||
|
||||
class ListProvidersResponse(BaseModel):
|
||||
data: List[ProviderInfo]
|
||||
|
||||
|
||||
class ListRoutesResponse(BaseModel):
|
||||
data: List[RouteInfo]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Inspect(Protocol):
|
||||
@webmethod(route="/inspect/providers", method="GET")
|
||||
async def list_providers(self) -> ListProvidersResponse: ...
|
||||
|
||||
@webmethod(route="/inspect/routes", method="GET")
|
||||
async def list_routes(self) -> ListRoutesResponse: ...
|
||||
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ class Models(Protocol):
|
|||
async def get_model(
|
||||
self,
|
||||
model_id: str,
|
||||
) -> Optional[Model]: ...
|
||||
) -> Model: ...
|
||||
|
||||
@webmethod(route="/models", method="POST")
|
||||
async def register_model(
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
|
@ -88,10 +88,8 @@ class QATFinetuningConfig(BaseModel):
|
|||
group_size: int
|
||||
|
||||
|
||||
AlgorithmConfig = register_schema(
|
||||
Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")],
|
||||
name="AlgorithmConfig",
|
||||
)
|
||||
AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
|
||||
register_schema(AlgorithmConfig, name="AlgorithmConfig")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -184,7 +182,7 @@ class PostTraining(Protocol):
|
|||
description="Model descriptor from `llama model list`",
|
||||
),
|
||||
checkpoint_dir: Optional[str] = None,
|
||||
algorithm_config: Optional[AlgorithmConfig] = None,
|
||||
algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None,
|
||||
) -> PostTrainingJob: ...
|
||||
|
||||
@webmethod(route="/post-training/preference-optimize", method="POST")
|
||||
|
|
@ -202,10 +200,10 @@ class PostTraining(Protocol):
|
|||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
|
||||
|
||||
@webmethod(route="/post-training/job/status", method="GET")
|
||||
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: ...
|
||||
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...
|
||||
|
||||
@webmethod(route="/post-training/job/cancel", method="POST")
|
||||
async def cancel_training_job(self, job_uuid: str) -> None: ...
|
||||
|
||||
@webmethod(route="/post-training/job/artifacts", method="GET")
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]: ...
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ...
|
||||
|
|
|
|||
|
|
@ -4,9 +4,4 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
||||
from .providers import * # noqa: F401 F403
|
||||
36
llama_stack/apis/providers/providers.py
Normal file
36
llama_stack/apis/providers/providers.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderInfo(BaseModel):
|
||||
api: str
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
class ListProvidersResponse(BaseModel):
|
||||
data: List[ProviderInfo]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Providers(Protocol):
|
||||
"""
|
||||
Providers API for inspecting, listing, and modifying providers and their configurations.
|
||||
"""
|
||||
|
||||
@webmethod(route="/providers", method="GET")
|
||||
async def list_providers(self) -> ListProvidersResponse: ...
|
||||
|
||||
@webmethod(route="/providers/{provider_id}", method="GET")
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo: ...
|
||||
|
|
@ -36,6 +36,7 @@ class ScoringFnParamsType(Enum):
|
|||
@json_schema_type
|
||||
class AggregationFunctionType(Enum):
|
||||
average = "average"
|
||||
weighted_average = "weighted_average"
|
||||
median = "median"
|
||||
categorical_count = "categorical_count"
|
||||
accuracy = "accuracy"
|
||||
|
|
@ -78,17 +79,15 @@ class BasicScoringFnParams(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
ScoringFnParams = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
LLMAsJudgeScoringFnParams,
|
||||
RegexParserScoringFnParams,
|
||||
BasicScoringFnParams,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
ScoringFnParams = Annotated[
|
||||
Union[
|
||||
LLMAsJudgeScoringFnParams,
|
||||
RegexParserScoringFnParams,
|
||||
BasicScoringFnParams,
|
||||
],
|
||||
name="ScoringFnParams",
|
||||
)
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(ScoringFnParams, name="ScoringFnParams")
|
||||
|
||||
|
||||
class CommonScoringFnFields(BaseModel):
|
||||
|
|
@ -135,7 +134,7 @@ class ScoringFunctions(Protocol):
|
|||
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
|
||||
|
||||
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
|
||||
async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ...
|
||||
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
|
||||
|
||||
@webmethod(route="/scoring-functions", method="POST")
|
||||
async def register_scoring_function(
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ class Shields(Protocol):
|
|||
async def list_shields(self) -> ListShieldsResponse: ...
|
||||
|
||||
@webmethod(route="/shields/{identifier:path}", method="GET")
|
||||
async def get_shield(self, identifier: str) -> Optional[Shield]: ...
|
||||
async def get_shield(self, identifier: str) -> Shield: ...
|
||||
|
||||
@webmethod(route="/shields", method="POST")
|
||||
async def register_shield(
|
||||
|
|
|
|||
|
|
@ -96,6 +96,13 @@ class MetricEvent(EventCommon):
|
|||
unit: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricInResponse(BaseModel):
|
||||
metric: str
|
||||
value: Union[int, float]
|
||||
unit: Optional[str] = None
|
||||
|
||||
|
||||
# This is a short term solution to allow inference API to return metrics
|
||||
# The ideal way to do this is to have a way for all response types to include metrics
|
||||
# and all metric events logged to the telemetry API to be inlcuded with the response
|
||||
|
|
@ -117,7 +124,7 @@ class MetricEvent(EventCommon):
|
|||
|
||||
|
||||
class MetricResponseMixin(BaseModel):
|
||||
metrics: Optional[List[MetricEvent]] = None
|
||||
metrics: Optional[List[MetricInResponse]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -139,16 +146,14 @@ class SpanEndPayload(BaseModel):
|
|||
status: SpanStatus
|
||||
|
||||
|
||||
StructuredLogPayload = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
SpanStartPayload,
|
||||
SpanEndPayload,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
StructuredLogPayload = Annotated[
|
||||
Union[
|
||||
SpanStartPayload,
|
||||
SpanEndPayload,
|
||||
],
|
||||
name="StructuredLogPayload",
|
||||
)
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(StructuredLogPayload, name="StructuredLogPayload")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -157,17 +162,15 @@ class StructuredLogEvent(EventCommon):
|
|||
payload: StructuredLogPayload
|
||||
|
||||
|
||||
Event = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
UnstructuredLogEvent,
|
||||
MetricEvent,
|
||||
StructuredLogEvent,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
Event = Annotated[
|
||||
Union[
|
||||
UnstructuredLogEvent,
|
||||
MetricEvent,
|
||||
StructuredLogEvent,
|
||||
],
|
||||
name="Event",
|
||||
)
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(Event, name="Event")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -17,6 +17,15 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
|
|||
|
||||
@json_schema_type
|
||||
class RAGDocument(BaseModel):
|
||||
"""
|
||||
A document to be used for document ingestion in the RAG Tool.
|
||||
|
||||
:param document_id: The unique identifier for the document.
|
||||
:param content: The content of the document.
|
||||
:param mime_type: The MIME type of the document.
|
||||
:param metadata: Additional metadata for the document.
|
||||
"""
|
||||
|
||||
document_id: str
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str | None = None
|
||||
|
|
@ -49,16 +58,14 @@ class LLMRAGQueryGeneratorConfig(BaseModel):
|
|||
template: str
|
||||
|
||||
|
||||
RAGQueryGeneratorConfig = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
DefaultRAGQueryGeneratorConfig,
|
||||
LLMRAGQueryGeneratorConfig,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
RAGQueryGeneratorConfig = Annotated[
|
||||
Union[
|
||||
DefaultRAGQueryGeneratorConfig,
|
||||
LLMRAGQueryGeneratorConfig,
|
||||
],
|
||||
name="RAGQueryGeneratorConfig",
|
||||
)
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ class VectorDBs(Protocol):
|
|||
async def get_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
) -> Optional[VectorDB]: ...
|
||||
) -> VectorDB: ...
|
||||
|
||||
@webmethod(route="/vector-dbs", method="POST")
|
||||
async def register_vector_db(
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import json
|
|||
import os
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
|
@ -404,7 +404,7 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
|
|||
d = json.load(f)
|
||||
manifest = Manifest(**d)
|
||||
|
||||
if datetime.now() > manifest.expires_on:
|
||||
if datetime.now(timezone.utc) > manifest.expires_on:
|
||||
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
|
||||
|
||||
console = Console()
|
||||
|
|
|
|||
|
|
@ -41,8 +41,8 @@ class ModelPromptFormat(Subcommand):
|
|||
"-m",
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="llama3_1",
|
||||
help="Model Family (llama3_1, llama3_X, etc.)",
|
||||
help="Example: Llama3.1-8B or Llama3.2-11B-Vision, etc\n"
|
||||
"(Run `llama model list` to see a list of valid model names)",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"-l",
|
||||
|
|
@ -60,7 +60,6 @@ class ModelPromptFormat(Subcommand):
|
|||
]
|
||||
|
||||
model_list = [m.value for m in supported_model_ids]
|
||||
model_str = "\n".join(model_list)
|
||||
|
||||
if args.list:
|
||||
headers = ["Model(s)"]
|
||||
|
|
@ -81,10 +80,16 @@ class ModelPromptFormat(Subcommand):
|
|||
try:
|
||||
model_id = CoreModelId(args.model_name)
|
||||
except ValueError:
|
||||
self.parser.error(f"{args.model_name} is not a valid Model. Choose one from --\n{model_str}")
|
||||
self.parser.error(
|
||||
f"{args.model_name} is not a valid Model. Choose one from the list of valid models. "
|
||||
f"Run `llama model list` to see the valid model names."
|
||||
)
|
||||
|
||||
if model_id not in supported_model_ids:
|
||||
self.parser.error(f"{model_id} is not a valid Model. Choose one from --\n {model_str}")
|
||||
self.parser.error(
|
||||
f"{model_id} is not a valid Model. Choose one from the list of valid models. "
|
||||
f"Run `llama model list` to see the valid model names."
|
||||
)
|
||||
|
||||
llama_3_1_file = ROOT_DIR / "models" / "llama" / "llama3_1" / "prompt_format.md"
|
||||
llama_3_2_text_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "text_prompt_format.md"
|
||||
|
|
|
|||
|
|
@ -38,8 +38,8 @@ from llama_stack.distribution.distribution import get_provider_registry
|
|||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
|
||||
from llama_stack.distribution.utils.image_types import ImageType
|
||||
from llama_stack.distribution.utils.exec import formulate_run_args, run_command
|
||||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
|
||||
|
|
@ -170,7 +170,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
)
|
||||
sys.exit(1)
|
||||
|
||||
if build_config.image_type == ImageType.container.value and not args.image_name:
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not args.image_name:
|
||||
cprint(
|
||||
"Please specify --image-name when building a container from a config file",
|
||||
color="red",
|
||||
|
|
@ -213,7 +213,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
|
||||
run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))])
|
||||
run_with_pty(run_args)
|
||||
run_command(run_args)
|
||||
|
||||
|
||||
def _generate_run_config(
|
||||
|
|
@ -226,7 +226,7 @@ def _generate_run_config(
|
|||
"""
|
||||
apis = list(build_config.distribution_spec.providers.keys())
|
||||
run_config = StackRunConfig(
|
||||
container_image=(image_name if build_config.image_type == ImageType.container.value else None),
|
||||
container_image=(image_name if build_config.image_type == LlamaStackImageType.CONTAINER.value else None),
|
||||
image_name=image_name,
|
||||
apis=apis,
|
||||
providers={},
|
||||
|
|
@ -279,16 +279,16 @@ def _run_stack_build_command_from_build_config(
|
|||
template_name: Optional[str] = None,
|
||||
config_path: Optional[str] = None,
|
||||
) -> str:
|
||||
if build_config.image_type == ImageType.container.value:
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
if template_name:
|
||||
image_name = f"distribution-{template_name}"
|
||||
else:
|
||||
if not image_name:
|
||||
raise ValueError("Please specify an image name when building a container image without a template")
|
||||
elif build_config.image_type == ImageType.conda.value:
|
||||
elif build_config.image_type == LlamaStackImageType.CONDA.value:
|
||||
if not image_name:
|
||||
raise ValueError("Please specify an image name when building a conda image")
|
||||
elif build_config.image_type == ImageType.venv.value:
|
||||
elif build_config.image_type == LlamaStackImageType.VENV.value:
|
||||
if not image_name and os.environ.get("UV_SYSTEM_PYTHON"):
|
||||
image_name = "__system__"
|
||||
if not image_name:
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ class StackRun(Subcommand):
|
|||
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
|
||||
from llama_stack.distribution.utils.exec import formulate_run_args, run_command
|
||||
|
||||
config_file = Path(args.config)
|
||||
has_yaml_suffix = args.config.endswith(".yaml")
|
||||
|
|
@ -136,4 +136,4 @@ class StackRun(Subcommand):
|
|||
|
||||
if args.tls_keyfile and args.tls_certfile:
|
||||
run_args.extend(["--tls-keyfile", args.tls_keyfile, "--tls-certfile", args.tls_certfile])
|
||||
run_with_pty(run_args)
|
||||
run_command(run_args)
|
||||
|
|
|
|||
81
llama_stack/distribution/access_control.py
Normal file
81
llama_stack/distribution/access_control.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutableObjectWithProvider
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
def check_access(obj: RoutableObjectWithProvider, user_attributes: Optional[Dict[str, Any]] = None) -> bool:
|
||||
"""Check if the current user has access to the given object, based on access attributes.
|
||||
|
||||
Access control algorithm:
|
||||
1. If the resource has no access_attributes, access is GRANTED to all authenticated users
|
||||
2. If the user has no attributes, access is DENIED to any object with access_attributes defined
|
||||
3. For each attribute category in the resource's access_attributes:
|
||||
a. If the user lacks that category, access is DENIED
|
||||
b. If the user has the category but none of the required values, access is DENIED
|
||||
c. If the user has at least one matching value in each required category, access is GRANTED
|
||||
|
||||
Example:
|
||||
# Resource requires:
|
||||
access_attributes = AccessAttributes(
|
||||
roles=["admin", "data-scientist"],
|
||||
teams=["ml-team"]
|
||||
)
|
||||
|
||||
# User has:
|
||||
user_attributes = {
|
||||
"roles": ["data-scientist", "engineer"],
|
||||
"teams": ["ml-team", "infra-team"],
|
||||
"projects": ["llama-3"]
|
||||
}
|
||||
|
||||
# Result: Access GRANTED
|
||||
# - User has the "data-scientist" role (matches one of the required roles)
|
||||
# - AND user is part of the "ml-team" (matches the required team)
|
||||
# - The extra "projects" attribute is ignored
|
||||
|
||||
Args:
|
||||
obj: The resource object to check access for
|
||||
|
||||
Returns:
|
||||
bool: True if access is granted, False if denied
|
||||
"""
|
||||
# If object has no access attributes, allow access by default
|
||||
if not hasattr(obj, "access_attributes") or not obj.access_attributes:
|
||||
return True
|
||||
|
||||
# If no user attributes, deny access to objects with access control
|
||||
if not user_attributes:
|
||||
return False
|
||||
|
||||
obj_attributes = obj.access_attributes.model_dump(exclude_none=True)
|
||||
if not obj_attributes:
|
||||
return True
|
||||
|
||||
# Check each attribute category (requires ALL categories to match)
|
||||
for attr_key, required_values in obj_attributes.items():
|
||||
user_values = user_attributes.get(attr_key, [])
|
||||
|
||||
if not user_values:
|
||||
logger.debug(
|
||||
f"Access denied to {obj.type} '{obj.identifier}': missing required attribute category '{attr_key}'"
|
||||
)
|
||||
return False
|
||||
|
||||
if not any(val in user_values for val in required_values):
|
||||
logger.debug(
|
||||
f"Access denied to {obj.type} '{obj.identifier}': "
|
||||
f"no match for attribute '{attr_key}', required one of {required_values}"
|
||||
)
|
||||
return False
|
||||
|
||||
logger.debug(f"Access granted to {obj.type} '{obj.identifier}'")
|
||||
return True
|
||||
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
import importlib.resources
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
|
|
@ -15,8 +14,8 @@ from termcolor import cprint
|
|||
|
||||
from llama_stack.distribution.datatypes import BuildConfig, Provider
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.utils.exec import run_command, run_with_pty
|
||||
from llama_stack.distribution.utils.image_types import ImageType
|
||||
from llama_stack.distribution.utils.exec import run_command
|
||||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
@ -95,7 +94,7 @@ def build_image(
|
|||
normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
|
||||
normal_deps += SERVER_DEPENDENCIES
|
||||
|
||||
if build_config.image_type == ImageType.container.value:
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")
|
||||
args = [
|
||||
script,
|
||||
|
|
@ -104,7 +103,7 @@ def build_image(
|
|||
container_base,
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
elif build_config.image_type == ImageType.conda.value:
|
||||
elif build_config.image_type == LlamaStackImageType.CONDA.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
|
||||
args = [
|
||||
script,
|
||||
|
|
@ -112,7 +111,7 @@ def build_image(
|
|||
str(build_file_path),
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
elif build_config.image_type == ImageType.venv.value:
|
||||
elif build_config.image_type == LlamaStackImageType.VENV.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh")
|
||||
args = [
|
||||
script,
|
||||
|
|
@ -123,11 +122,7 @@ def build_image(
|
|||
if special_deps:
|
||||
args.append("#".join(special_deps))
|
||||
|
||||
is_terminal = sys.stdin.isatty()
|
||||
if is_terminal:
|
||||
return_code = run_with_pty(args)
|
||||
else:
|
||||
return_code = run_command(args)
|
||||
return_code = run_command(args)
|
||||
|
||||
if return_code != 0:
|
||||
log.error(
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ RED='\033[0;31m'
|
|||
NC='\033[0m' # No Color
|
||||
|
||||
CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
|
||||
CONTAINER_OPTS=${CONTAINER_OPTS:-}
|
||||
CONTAINER_OPTS=${CONTAINER_OPTS:---progress=plain}
|
||||
|
||||
TEMP_DIR=$(mktemp -d)
|
||||
|
||||
|
|
@ -235,7 +235,7 @@ image_tag="$image_name:$version_tag"
|
|||
# Detect platform architecture
|
||||
ARCH=$(uname -m)
|
||||
if [ -n "$BUILD_PLATFORM" ]; then
|
||||
CLI_ARGS+=("--platform $BUILD_PLATFORM")
|
||||
CLI_ARGS+=("--platform" "$BUILD_PLATFORM")
|
||||
elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
|
||||
CLI_ARGS+=("--platform" "linux/arm64")
|
||||
elif [ "$ARCH" = "x86_64" ]; then
|
||||
|
|
@ -253,8 +253,7 @@ $CONTAINER_BINARY build \
|
|||
"${CLI_ARGS[@]}" \
|
||||
-t "$image_tag" \
|
||||
-f "$TEMP_DIR/Containerfile" \
|
||||
"." \
|
||||
--progress=plain
|
||||
"."
|
||||
|
||||
# clean up tmp/configs
|
||||
set +x
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
|
|||
if config.apis:
|
||||
apis_to_serve = config.apis
|
||||
else:
|
||||
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
|
||||
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect, Api.providers)]
|
||||
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from llama_stack.apis.datasets import Dataset, DatasetInput
|
|||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Model, ModelInput
|
||||
from llama_stack.apis.resource import Resource
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
|
||||
|
|
@ -31,6 +32,115 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
|
|||
RoutingKey = Union[str, List[str]]
|
||||
|
||||
|
||||
class AccessAttributes(BaseModel):
|
||||
"""Structured representation of user attributes for access control.
|
||||
|
||||
This model defines a structured approach to representing user attributes
|
||||
with common standard categories for access control.
|
||||
|
||||
Standard attribute categories include:
|
||||
- roles: Role-based attributes (e.g., admin, data-scientist)
|
||||
- teams: Team-based attributes (e.g., ml-team, infra-team)
|
||||
- projects: Project access attributes (e.g., llama-3, customer-insights)
|
||||
- namespaces: Namespace-based access control for resource isolation
|
||||
"""
|
||||
|
||||
# Standard attribute categories - the minimal set we need now
|
||||
roles: Optional[List[str]] = Field(
|
||||
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
|
||||
)
|
||||
|
||||
teams: Optional[List[str]] = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
|
||||
|
||||
projects: Optional[List[str]] = Field(
|
||||
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
|
||||
)
|
||||
|
||||
namespaces: Optional[List[str]] = Field(
|
||||
default=None, description="Namespace-based access control for resource isolation"
|
||||
)
|
||||
|
||||
|
||||
class ResourceWithACL(Resource):
|
||||
"""Extension of Resource that adds attribute-based access control capabilities.
|
||||
|
||||
This class adds an optional access_attributes field that allows fine-grained control
|
||||
over which users can access each resource. When attributes are defined, a user must have
|
||||
matching attributes to access the resource.
|
||||
|
||||
Attribute Matching Algorithm:
|
||||
1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users
|
||||
2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects")
|
||||
3. The matching algorithm requires ALL categories to match (AND relationship between categories)
|
||||
4. Within each category, ANY value match is sufficient (OR relationship within a category)
|
||||
|
||||
Examples:
|
||||
# Resource visible to everyone (no access control)
|
||||
model = Model(identifier="llama-2", ...)
|
||||
|
||||
# Resource visible only to admins
|
||||
model = Model(
|
||||
identifier="gpt-4",
|
||||
access_attributes=AccessAttributes(roles=["admin"])
|
||||
)
|
||||
|
||||
# Resource visible to data scientists on the ML team
|
||||
model = Model(
|
||||
identifier="private-model",
|
||||
access_attributes=AccessAttributes(
|
||||
roles=["data-scientist", "researcher"],
|
||||
teams=["ml-team"]
|
||||
)
|
||||
)
|
||||
# ^ User must have at least one of the roles AND be on the ml-team
|
||||
|
||||
# Resource visible to users with specific project access
|
||||
vector_db = VectorDB(
|
||||
identifier="customer-embeddings",
|
||||
access_attributes=AccessAttributes(
|
||||
projects=["customer-insights"],
|
||||
namespaces=["confidential"]
|
||||
)
|
||||
)
|
||||
# ^ User must have access to the customer-insights project AND have confidential namespace
|
||||
"""
|
||||
|
||||
access_attributes: Optional[AccessAttributes] = None
|
||||
|
||||
|
||||
# Use the extended Resource for all routable objects
|
||||
class ModelWithACL(Model, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class ShieldWithACL(Shield, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class VectorDBWithACL(VectorDB, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class DatasetWithACL(Dataset, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class ScoringFnWithACL(ScoringFn, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class BenchmarkWithACL(Benchmark, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class ToolWithACL(Tool, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class ToolGroupWithACL(ToolGroup, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
RoutableObject = Union[
|
||||
Model,
|
||||
Shield,
|
||||
|
|
@ -45,14 +155,14 @@ RoutableObject = Union[
|
|||
|
||||
RoutableObjectWithProvider = Annotated[
|
||||
Union[
|
||||
Model,
|
||||
Shield,
|
||||
VectorDB,
|
||||
Dataset,
|
||||
ScoringFn,
|
||||
Benchmark,
|
||||
Tool,
|
||||
ToolGroup,
|
||||
ModelWithACL,
|
||||
ShieldWithACL,
|
||||
VectorDBWithACL,
|
||||
DatasetWithACL,
|
||||
ScoringFnWithACL,
|
||||
BenchmarkWithACL,
|
||||
ToolWithACL,
|
||||
ToolGroupWithACL,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
|
@ -117,6 +227,21 @@ class Provider(BaseModel):
|
|||
config: Dict[str, Any]
|
||||
|
||||
|
||||
class LoggingConfig(BaseModel):
|
||||
category_levels: Dict[str, str] = Field(
|
||||
default_factory=Dict,
|
||||
description="""
|
||||
Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""",
|
||||
)
|
||||
|
||||
|
||||
class AuthenticationConfig(BaseModel):
|
||||
endpoint: str = Field(
|
||||
...,
|
||||
description="Endpoint URL to validate authentication tokens",
|
||||
)
|
||||
|
||||
|
||||
class ServerConfig(BaseModel):
|
||||
port: int = Field(
|
||||
default=8321,
|
||||
|
|
@ -132,6 +257,10 @@ class ServerConfig(BaseModel):
|
|||
default=None,
|
||||
description="Path to TLS key file for HTTPS",
|
||||
)
|
||||
auth: Optional[AuthenticationConfig] = Field(
|
||||
default=None,
|
||||
description="Authentication configuration for the server",
|
||||
)
|
||||
|
||||
|
||||
class StackRunConfig(BaseModel):
|
||||
|
|
@ -176,6 +305,8 @@ a default SQLite store will be used.""",
|
|||
benchmarks: List[BenchmarkInput] = Field(default_factory=list)
|
||||
tool_groups: List[ToolGroupInput] = Field(default_factory=list)
|
||||
|
||||
logging: Optional[LoggingConfig] = Field(default=None, description="Configuration for Llama Stack Logging")
|
||||
|
||||
server: ServerConfig = Field(
|
||||
default_factory=ServerConfig,
|
||||
description="Configuration for the HTTP(S) server",
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
|
|||
|
||||
def providable_apis() -> List[Api]:
|
||||
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
|
||||
return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
|
||||
return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
|
||||
|
||||
|
||||
def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||
|
|
|
|||
|
|
@ -11,9 +11,7 @@ from pydantic import BaseModel
|
|||
from llama_stack.apis.inspect import (
|
||||
HealthInfo,
|
||||
Inspect,
|
||||
ListProvidersResponse,
|
||||
ListRoutesResponse,
|
||||
ProviderInfo,
|
||||
RouteInfo,
|
||||
VersionInfo,
|
||||
)
|
||||
|
|
@ -39,24 +37,6 @@ class DistributionInspectImpl(Inspect):
|
|||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_providers(self) -> ListProvidersResponse:
|
||||
run_config = self.config.run_config
|
||||
|
||||
ret = []
|
||||
for api, providers in run_config.providers.items():
|
||||
ret.extend(
|
||||
[
|
||||
ProviderInfo(
|
||||
api=api,
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
)
|
||||
for p in providers
|
||||
]
|
||||
)
|
||||
|
||||
return ListProvidersResponse(data=ret)
|
||||
|
||||
async def list_routes(self) -> ListRoutesResponse:
|
||||
run_config = self.config.run_config
|
||||
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ from llama_stack.distribution.build import print_pip_install_help
|
|||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.distribution.request_headers import (
|
||||
preserve_headers_context_async_generator,
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.resolver import ProviderRegistry
|
||||
|
|
@ -44,8 +44,10 @@ from llama_stack.distribution.stack import (
|
|||
redact_sensitive_fields,
|
||||
replace_env_vars,
|
||||
)
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.distribution.utils.exec import in_notebook
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
CURRENT_TRACE_CONTEXT,
|
||||
end_trace,
|
||||
setup_logger,
|
||||
start_trace,
|
||||
|
|
@ -384,8 +386,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
finally:
|
||||
await end_trace()
|
||||
|
||||
# Wrap the generator to preserve context across iterations
|
||||
wrapped_gen = preserve_headers_context_async_generator(gen())
|
||||
wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
|
||||
|
||||
mock_response = httpx.Response(
|
||||
status_code=httpx.codes.OK,
|
||||
content=wrapped_gen,
|
||||
|
|
|
|||
66
llama_stack/distribution/providers.py
Normal file
66
llama_stack/distribution/providers.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .datatypes import StackRunConfig
|
||||
from .stack import redact_sensitive_fields
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class ProviderImplConfig(BaseModel):
|
||||
run_config: StackRunConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config, deps):
|
||||
impl = ProviderImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
class ProviderImpl(Providers):
|
||||
def __init__(self, config, deps):
|
||||
self.config = config
|
||||
self.deps = deps
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("ProviderImpl.shutdown")
|
||||
pass
|
||||
|
||||
async def list_providers(self) -> ListProvidersResponse:
|
||||
run_config = self.config.run_config
|
||||
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
|
||||
ret = []
|
||||
for api, providers in safe_config.providers.items():
|
||||
ret.extend(
|
||||
[
|
||||
ProviderInfo(
|
||||
api=api,
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
config=p.config,
|
||||
)
|
||||
for p in providers
|
||||
]
|
||||
)
|
||||
|
||||
return ListProvidersResponse(data=ret)
|
||||
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
|
||||
all_providers = await self.list_providers()
|
||||
for p in all_providers.data:
|
||||
if p.provider_id == provider_id:
|
||||
return p
|
||||
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
|
@ -7,59 +7,37 @@
|
|||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar
|
||||
from typing import Any, ContextManager, Dict, List, Optional
|
||||
|
||||
from .utils.dynamic import instantiate_class_type
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Context variable for request provider data
|
||||
_provider_data_var = contextvars.ContextVar("provider_data", default=None)
|
||||
# Context variable for request provider data and auth attributes
|
||||
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
|
||||
|
||||
|
||||
class RequestProviderDataContext(ContextManager):
|
||||
"""Context manager for request provider data"""
|
||||
|
||||
def __init__(self, provider_data: Optional[Dict[str, Any]] = None):
|
||||
self.provider_data = provider_data
|
||||
def __init__(
|
||||
self, provider_data: Optional[Dict[str, Any]] = None, auth_attributes: Optional[Dict[str, List[str]]] = None
|
||||
):
|
||||
self.provider_data = provider_data or {}
|
||||
if auth_attributes:
|
||||
self.provider_data["__auth_attributes"] = auth_attributes
|
||||
|
||||
self.token = None
|
||||
|
||||
def __enter__(self):
|
||||
# Save the current value and set the new one
|
||||
self.token = _provider_data_var.set(self.provider_data)
|
||||
self.token = PROVIDER_DATA_VAR.set(self.provider_data)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# Restore the previous value
|
||||
if self.token is not None:
|
||||
_provider_data_var.reset(self.token)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
|
||||
"""
|
||||
Wraps an async generator to preserve request headers context variables across iterations.
|
||||
|
||||
This ensures that context variables set during generator creation are
|
||||
available during each iteration of the generator, even if the original
|
||||
context manager has exited.
|
||||
"""
|
||||
# Capture the current context value right now
|
||||
context_value = _provider_data_var.get()
|
||||
|
||||
async def wrapper():
|
||||
while True:
|
||||
# Set context before each anext() call
|
||||
_ = _provider_data_var.set(context_value)
|
||||
try:
|
||||
item = await gen.__anext__()
|
||||
yield item
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
return wrapper()
|
||||
PROVIDER_DATA_VAR.reset(self.token)
|
||||
|
||||
|
||||
class NeedsRequestProviderData:
|
||||
|
|
@ -72,7 +50,7 @@ class NeedsRequestProviderData:
|
|||
if not validator_class:
|
||||
raise ValueError(f"Provider {provider_type} does not have a validator")
|
||||
|
||||
val = _provider_data_var.get()
|
||||
val = PROVIDER_DATA_VAR.get()
|
||||
if not val:
|
||||
return None
|
||||
|
||||
|
|
@ -107,7 +85,17 @@ def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, A
|
|||
return None
|
||||
|
||||
|
||||
def request_provider_data_context(headers: Dict[str, str]) -> ContextManager:
|
||||
"""Context manager that sets request provider data from headers for the duration of the context"""
|
||||
def request_provider_data_context(
|
||||
headers: Dict[str, str], auth_attributes: Optional[Dict[str, List[str]]] = None
|
||||
) -> ContextManager:
|
||||
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
|
||||
provider_data = parse_request_provider_data(headers)
|
||||
return RequestProviderDataContext(provider_data)
|
||||
return RequestProviderDataContext(provider_data, auth_attributes)
|
||||
|
||||
|
||||
def get_auth_attributes() -> Optional[Dict[str, List[str]]]:
|
||||
"""Helper to retrieve auth attributes from the provider data context"""
|
||||
provider_data = PROVIDER_DATA_VAR.get()
|
||||
if not provider_data:
|
||||
return None
|
||||
return provider_data.get("__auth_attributes")
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from llama_stack.apis.inference import Inference
|
|||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.post_training import PostTraining
|
||||
from llama_stack.apis.providers import Providers as ProvidersAPI
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||
|
|
@ -59,6 +60,7 @@ class InvalidProviderError(Exception):
|
|||
|
||||
def api_protocol_map() -> Dict[Api, Any]:
|
||||
return {
|
||||
Api.providers: ProvidersAPI,
|
||||
Api.agents: Agents,
|
||||
Api.inference: Inference,
|
||||
Api.inspect: Inspect,
|
||||
|
|
@ -165,7 +167,9 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
|
|||
module="llama_stack.distribution.routers",
|
||||
routing_table_api=info.routing_table_api,
|
||||
api_dependencies=[info.routing_table_api],
|
||||
deps__=[info.routing_table_api.value],
|
||||
# Add telemetry as an optional dependency to all auto-routed providers
|
||||
optional_api_dependencies=[Api.telemetry],
|
||||
deps__=([info.routing_table_api.value, Api.telemetry.value]),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
|
@ -245,6 +249,25 @@ def sort_providers_by_deps(
|
|||
)
|
||||
)
|
||||
|
||||
sorted_providers.append(
|
||||
(
|
||||
"providers",
|
||||
ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
provider_type="__builtin__",
|
||||
config={"run_config": run_config.model_dump()},
|
||||
spec=InlineProviderSpec(
|
||||
api=Api.providers,
|
||||
provider_type="__builtin__",
|
||||
config_class="llama_stack.distribution.providers.ProviderImplConfig",
|
||||
module="llama_stack.distribution.providers",
|
||||
api_dependencies=apis,
|
||||
deps__=[x.value for x in apis],
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(f"Resolved {len(sorted_providers)} providers")
|
||||
for api_str, provider in sorted_providers:
|
||||
logger.debug(f" {api_str} => {provider.provider_id}")
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ async def get_routing_table_impl(
|
|||
return impl
|
||||
|
||||
|
||||
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
|
||||
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
|
||||
from .routers import (
|
||||
DatasetIORouter,
|
||||
EvalRouter,
|
||||
|
|
@ -65,9 +65,17 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
|
|||
"eval": EvalRouter,
|
||||
"tool_runtime": ToolRuntimeRouter,
|
||||
}
|
||||
api_to_deps = {
|
||||
"inference": {"telemetry": Api.telemetry},
|
||||
}
|
||||
if api.value not in api_to_routers:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
||||
impl = api_to_routers[api.value](routing_table)
|
||||
api_to_dep_impl = {}
|
||||
for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
|
||||
if dep_api in deps:
|
||||
api_to_dep_impl[dep_name] = deps[dep_api]
|
||||
|
||||
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -4,14 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
||||
from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
||||
from llama_stack.apis.eval import (
|
||||
BenchmarkConfig,
|
||||
Eval,
|
||||
|
|
@ -20,6 +22,10 @@ from llama_stack.apis.eval import (
|
|||
JobStatus,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
|
|
@ -27,13 +33,14 @@ from llama_stack.apis.inference import (
|
|||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||
from llama_stack.apis.scoring import (
|
||||
ScoreBatchResponse,
|
||||
|
|
@ -42,6 +49,7 @@ from llama_stack.apis.scoring import (
|
|||
ScoringFnParams,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
|
||||
from llama_stack.apis.tools import (
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
|
|
@ -52,7 +60,10 @@ from llama_stack.apis.tools import (
|
|||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
|
@ -119,9 +130,14 @@ class InferenceRouter(Inference):
|
|||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
telemetry: Optional[Telemetry] = None,
|
||||
) -> None:
|
||||
logger.debug("Initializing InferenceRouter")
|
||||
self.routing_table = routing_table
|
||||
self.telemetry = telemetry
|
||||
if self.telemetry:
|
||||
self.tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(self.tokenizer)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("InferenceRouter.initialize")
|
||||
|
|
@ -144,6 +160,75 @@ class InferenceRouter(Inference):
|
|||
)
|
||||
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
||||
|
||||
def _construct_metrics(
|
||||
self,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
model: Model,
|
||||
) -> List[MetricEvent]:
|
||||
"""Constructs a list of MetricEvent objects containing token usage metrics.
|
||||
|
||||
Args:
|
||||
prompt_tokens: Number of tokens in the prompt
|
||||
completion_tokens: Number of tokens in the completion
|
||||
total_tokens: Total number of tokens used
|
||||
model: Model object containing model_id and provider_id
|
||||
|
||||
Returns:
|
||||
List of MetricEvent objects with token usage metrics
|
||||
"""
|
||||
span = get_current_span()
|
||||
if span is None:
|
||||
logger.warning("No span found for token usage metrics")
|
||||
return []
|
||||
metrics = [
|
||||
("prompt_tokens", prompt_tokens),
|
||||
("completion_tokens", completion_tokens),
|
||||
("total_tokens", total_tokens),
|
||||
]
|
||||
metric_events = []
|
||||
for metric_name, value in metrics:
|
||||
metric_events.append(
|
||||
MetricEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
metric=metric_name,
|
||||
value=value,
|
||||
timestamp=time.time(),
|
||||
unit="tokens",
|
||||
attributes={
|
||||
"model_id": model.model_id,
|
||||
"provider_id": model.provider_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
return metric_events
|
||||
|
||||
async def _compute_and_log_token_usage(
|
||||
self,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
model: Model,
|
||||
) -> List[MetricInResponse]:
|
||||
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
|
||||
if self.telemetry:
|
||||
for metric in metrics:
|
||||
await self.telemetry.log_event(metric)
|
||||
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
|
||||
|
||||
async def _count_tokens(
|
||||
self,
|
||||
messages: List[Message] | InterleavedContent,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
) -> Optional[int]:
|
||||
if isinstance(messages, list):
|
||||
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
|
||||
else:
|
||||
encoded = self.formatter.encode_content(messages)
|
||||
return len(encoded.tokens) if encoded and encoded.tokens else 0
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
@ -156,7 +241,7 @@ class InferenceRouter(Inference):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
logger.debug(
|
||||
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
||||
)
|
||||
|
|
@ -206,10 +291,52 @@ class InferenceRouter(Inference):
|
|||
tool_config=tool_config,
|
||||
)
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
|
||||
|
||||
if stream:
|
||||
return (chunk async for chunk in await provider.chat_completion(**params))
|
||||
|
||||
async def stream_generator():
|
||||
completion_text = ""
|
||||
async for chunk in await provider.chat_completion(**params):
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
|
||||
if chunk.event.delta.type == "text":
|
||||
completion_text += chunk.event.delta.text
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
|
||||
completion_tokens = await self._count_tokens(
|
||||
[
|
||||
CompletionMessage(
|
||||
content=completion_text,
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
],
|
||||
tool_config.tool_prompt_format,
|
||||
)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
||||
yield chunk
|
||||
|
||||
return stream_generator()
|
||||
else:
|
||||
return await provider.chat_completion(**params)
|
||||
response = await provider.chat_completion(**params)
|
||||
completion_tokens = await self._count_tokens(
|
||||
[response.completion_message],
|
||||
tool_config.tool_prompt_format,
|
||||
)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||
return response
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
|
|
@ -239,10 +366,41 @@ class InferenceRouter(Inference):
|
|||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
prompt_tokens = await self._count_tokens(content)
|
||||
|
||||
if stream:
|
||||
return (chunk async for chunk in await provider.completion(**params))
|
||||
|
||||
async def stream_generator():
|
||||
completion_text = ""
|
||||
async for chunk in await provider.completion(**params):
|
||||
if hasattr(chunk, "delta"):
|
||||
completion_text += chunk.delta
|
||||
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
|
||||
completion_tokens = await self._count_tokens(completion_text)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
||||
yield chunk
|
||||
|
||||
return stream_generator()
|
||||
else:
|
||||
return await provider.completion(**params)
|
||||
response = await provider.completion(**params)
|
||||
completion_tokens = await self._count_tokens(response.content)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||
return response
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
|
|
@ -323,21 +481,36 @@ class DatasetIORouter(DatasetIO):
|
|||
logger.debug("DatasetIORouter.shutdown")
|
||||
pass
|
||||
|
||||
async def get_rows_paginated(
|
||||
async def register_dataset(
|
||||
self,
|
||||
purpose: DatasetPurpose,
|
||||
source: DataSource,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
dataset_id: Optional[str] = None,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
|
||||
)
|
||||
await self.routing_table.register_dataset(
|
||||
purpose=purpose,
|
||||
source=source,
|
||||
metadata=metadata,
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
|
||||
async def iterrows(
|
||||
self,
|
||||
dataset_id: str,
|
||||
rows_in_page: int,
|
||||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult:
|
||||
start_index: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> IterrowsResponse:
|
||||
logger.debug(
|
||||
f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}",
|
||||
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
|
||||
return await self.routing_table.get_provider_impl(dataset_id).iterrows(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=rows_in_page,
|
||||
page_token=page_token,
|
||||
filter_condition=filter_condition,
|
||||
start_index=start_index,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
|
@ -12,7 +13,16 @@ from pydantic import TypeAdapter
|
|||
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
|
||||
from llama_stack.apis.datasets import (
|
||||
Dataset,
|
||||
DatasetPurpose,
|
||||
Datasets,
|
||||
DatasetType,
|
||||
DataSource,
|
||||
ListDatasetsResponse,
|
||||
RowsDataSource,
|
||||
URIDataSource,
|
||||
)
|
||||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
|
|
@ -31,11 +41,22 @@ from llama_stack.apis.tools import (
|
|||
ToolHost,
|
||||
)
|
||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||
from llama_stack.distribution.access_control import check_access
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AccessAttributes,
|
||||
BenchmarkWithACL,
|
||||
DatasetWithACL,
|
||||
ModelWithACL,
|
||||
RoutableObject,
|
||||
RoutableObjectWithProvider,
|
||||
RoutedProtocol,
|
||||
ScoringFnWithACL,
|
||||
ShieldWithACL,
|
||||
ToolGroupWithACL,
|
||||
ToolWithACL,
|
||||
VectorDBWithACL,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import get_auth_attributes
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
|
||||
|
|
@ -176,6 +197,11 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
if not obj:
|
||||
return None
|
||||
|
||||
# Check if user has permission to access this object
|
||||
if not check_access(obj, get_auth_attributes()):
|
||||
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
|
||||
return None
|
||||
|
||||
return obj
|
||||
|
||||
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
|
||||
|
|
@ -192,6 +218,13 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
|
||||
# If object supports access control but no attributes set, use creator's attributes
|
||||
if not obj.access_attributes:
|
||||
creator_attributes = get_auth_attributes()
|
||||
if creator_attributes:
|
||||
obj.access_attributes = AccessAttributes(**creator_attributes)
|
||||
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
|
||||
|
||||
registered_obj = await register_object_with_provider(obj, p)
|
||||
# TODO: This needs to be fixed for all APIs once they return the registered object
|
||||
if obj.type == ResourceType.model.value:
|
||||
|
|
@ -204,15 +237,24 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
|
||||
objs = await self.dist_registry.get_all()
|
||||
return [obj for obj in objs if obj.type == type]
|
||||
filtered_objs = [obj for obj in objs if obj.type == type]
|
||||
|
||||
# Apply attribute-based access control filtering
|
||||
if filtered_objs:
|
||||
filtered_objs = [obj for obj in filtered_objs if check_access(obj, get_auth_attributes())]
|
||||
|
||||
return filtered_objs
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
async def list_models(self) -> ListModelsResponse:
|
||||
return ListModelsResponse(data=await self.get_all_with_type("model"))
|
||||
|
||||
async def get_model(self, model_id: str) -> Optional[Model]:
|
||||
return await self.get_object_by_identifier("model", model_id)
|
||||
async def get_model(self, model_id: str) -> Model:
|
||||
model = await self.get_object_by_identifier("model", model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
return model
|
||||
|
||||
async def register_model(
|
||||
self,
|
||||
|
|
@ -238,7 +280,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
model_type = ModelType.llm
|
||||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||
model = Model(
|
||||
model = ModelWithACL(
|
||||
identifier=model_id,
|
||||
provider_resource_id=provider_model_id,
|
||||
provider_id=provider_id,
|
||||
|
|
@ -259,8 +301,11 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
async def list_shields(self) -> ListShieldsResponse:
|
||||
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
|
||||
|
||||
async def get_shield(self, identifier: str) -> Optional[Shield]:
|
||||
return await self.get_object_by_identifier("shield", identifier)
|
||||
async def get_shield(self, identifier: str) -> Shield:
|
||||
shield = await self.get_object_by_identifier("shield", identifier)
|
||||
if shield is None:
|
||||
raise ValueError(f"Shield '{identifier}' not found")
|
||||
return shield
|
||||
|
||||
async def register_shield(
|
||||
self,
|
||||
|
|
@ -281,7 +326,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
)
|
||||
if params is None:
|
||||
params = {}
|
||||
shield = Shield(
|
||||
shield = ShieldWithACL(
|
||||
identifier=shield_id,
|
||||
provider_resource_id=provider_shield_id,
|
||||
provider_id=provider_id,
|
||||
|
|
@ -295,8 +340,11 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
async def list_vector_dbs(self) -> ListVectorDBsResponse:
|
||||
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
|
||||
|
||||
async def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]:
|
||||
return await self.get_object_by_identifier("vector_db", vector_db_id)
|
||||
async def get_vector_db(self, vector_db_id: str) -> VectorDB:
|
||||
vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
|
||||
if vector_db is None:
|
||||
raise ValueError(f"Vector DB '{vector_db_id}' not found")
|
||||
return vector_db
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
|
|
@ -332,7 +380,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
"embedding_model": embedding_model,
|
||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||
}
|
||||
vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data)
|
||||
vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
|
||||
await self.register_object(vector_db)
|
||||
return vector_db
|
||||
|
||||
|
|
@ -347,39 +395,56 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
async def list_datasets(self) -> ListDatasetsResponse:
|
||||
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
|
||||
|
||||
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
|
||||
return await self.get_object_by_identifier("dataset", dataset_id)
|
||||
async def get_dataset(self, dataset_id: str) -> Dataset:
|
||||
dataset = await self.get_object_by_identifier("dataset", dataset_id)
|
||||
if dataset is None:
|
||||
raise ValueError(f"Dataset '{dataset_id}' not found")
|
||||
return dataset
|
||||
|
||||
async def register_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
dataset_schema: Dict[str, ParamType],
|
||||
url: URL,
|
||||
provider_dataset_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
purpose: DatasetPurpose,
|
||||
source: DataSource,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
if provider_dataset_id is None:
|
||||
provider_dataset_id = dataset_id
|
||||
if provider_id is None:
|
||||
# If provider_id not specified, use the only provider if it supports this dataset
|
||||
if len(self.impls_by_provider_id) == 1:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
dataset_id: Optional[str] = None,
|
||||
) -> Dataset:
|
||||
if isinstance(source, dict):
|
||||
if source["type"] == "uri":
|
||||
source = URIDataSource.parse_obj(source)
|
||||
elif source["type"] == "rows":
|
||||
source = RowsDataSource.parse_obj(source)
|
||||
|
||||
if not dataset_id:
|
||||
dataset_id = f"dataset-{str(uuid.uuid4())}"
|
||||
|
||||
provider_dataset_id = dataset_id
|
||||
|
||||
# infer provider from source
|
||||
if source.type == DatasetType.rows.value:
|
||||
provider_id = "localfs"
|
||||
elif source.type == DatasetType.uri.value:
|
||||
# infer provider from uri
|
||||
if source.uri.startswith("huggingface"):
|
||||
provider_id = "huggingface"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
|
||||
)
|
||||
provider_id = "localfs"
|
||||
else:
|
||||
raise ValueError(f"Unknown data source type: {source.type}")
|
||||
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
dataset = Dataset(
|
||||
|
||||
dataset = DatasetWithACL(
|
||||
identifier=dataset_id,
|
||||
provider_resource_id=provider_dataset_id,
|
||||
provider_id=provider_id,
|
||||
dataset_schema=dataset_schema,
|
||||
url=url,
|
||||
purpose=purpose,
|
||||
source=source,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
await self.register_object(dataset)
|
||||
return dataset
|
||||
|
||||
async def unregister_dataset(self, dataset_id: str) -> None:
|
||||
dataset = await self.get_dataset(dataset_id)
|
||||
|
|
@ -392,8 +457,11 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
|||
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
|
||||
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
|
||||
|
||||
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
|
||||
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
|
||||
async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
|
||||
scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
|
||||
if scoring_fn is None:
|
||||
raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
|
||||
return scoring_fn
|
||||
|
||||
async def register_scoring_function(
|
||||
self,
|
||||
|
|
@ -413,7 +481,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
|||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
)
|
||||
scoring_fn = ScoringFn(
|
||||
scoring_fn = ScoringFnWithACL(
|
||||
identifier=scoring_fn_id,
|
||||
description=description,
|
||||
return_type=return_type,
|
||||
|
|
@ -429,8 +497,11 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
|||
async def list_benchmarks(self) -> ListBenchmarksResponse:
|
||||
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
|
||||
|
||||
async def get_benchmark(self, benchmark_id: str) -> Optional[Benchmark]:
|
||||
return await self.get_object_by_identifier("benchmark", benchmark_id)
|
||||
async def get_benchmark(self, benchmark_id: str) -> Benchmark:
|
||||
benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
|
||||
if benchmark is None:
|
||||
raise ValueError(f"Benchmark '{benchmark_id}' not found")
|
||||
return benchmark
|
||||
|
||||
async def register_benchmark(
|
||||
self,
|
||||
|
|
@ -452,7 +523,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
|||
)
|
||||
if provider_benchmark_id is None:
|
||||
provider_benchmark_id = benchmark_id
|
||||
benchmark = Benchmark(
|
||||
benchmark = BenchmarkWithACL(
|
||||
identifier=benchmark_id,
|
||||
dataset_id=dataset_id,
|
||||
scoring_functions=scoring_functions,
|
||||
|
|
@ -474,7 +545,10 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
|
||||
|
||||
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
|
||||
return await self.get_object_by_identifier("tool_group", toolgroup_id)
|
||||
tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
|
||||
if tool_group is None:
|
||||
raise ValueError(f"Tool group '{toolgroup_id}' not found")
|
||||
return tool_group
|
||||
|
||||
async def get_tool(self, tool_name: str) -> Tool:
|
||||
return await self.get_object_by_identifier("tool", tool_name)
|
||||
|
|
@ -492,7 +566,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
|
||||
for tool_def in tool_defs:
|
||||
tools.append(
|
||||
Tool(
|
||||
ToolWithACL(
|
||||
identifier=tool_def.name,
|
||||
toolgroup_id=toolgroup_id,
|
||||
description=tool_def.description or "",
|
||||
|
|
@ -517,7 +591,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
await self.register_object(tool)
|
||||
|
||||
await self.dist_registry.register(
|
||||
ToolGroup(
|
||||
ToolGroupWithACL(
|
||||
identifier=toolgroup_id,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=toolgroup_id,
|
||||
|
|
@ -530,7 +604,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
tool_group = await self.get_tool_group(toolgroup_id)
|
||||
if tool_group is None:
|
||||
raise ValueError(f"Tool group {toolgroup_id} not found")
|
||||
tools = await self.list_tools(toolgroup_id).data
|
||||
tools = (await self.list_tools(toolgroup_id)).data
|
||||
for tool in tools:
|
||||
await self.unregister_object(tool)
|
||||
await self.unregister_object(tool_group)
|
||||
|
|
|
|||
203
llama_stack/distribution/server/auth.py
Normal file
203
llama_stack/distribution/server/auth.py
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
|
||||
|
||||
class AuthRequestContext(BaseModel):
|
||||
path: str = Field(description="The path of the request being authenticated")
|
||||
|
||||
headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
|
||||
|
||||
params: Dict[str, List[str]] = Field(
|
||||
description="Query parameters from the original request, parsed as dictionary of lists"
|
||||
)
|
||||
|
||||
|
||||
class AuthRequest(BaseModel):
|
||||
api_key: str = Field(description="The API key extracted from the Authorization header")
|
||||
|
||||
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
"""The format of the authentication response from the auth endpoint."""
|
||||
|
||||
access_attributes: Optional[AccessAttributes] = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Structured user attributes for attribute-based access control.
|
||||
|
||||
These attributes determine which resources the user can access.
|
||||
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
|
||||
Each attribute category contains a list of values that the user has for that category.
|
||||
During access control checks, these values are compared against resource requirements.
|
||||
|
||||
Example with standard categories:
|
||||
```json
|
||||
{
|
||||
"roles": ["admin", "data-scientist"],
|
||||
"teams": ["ml-team"],
|
||||
"projects": ["llama-3"],
|
||||
"namespaces": ["research"]
|
||||
}
|
||||
```
|
||||
""",
|
||||
)
|
||||
|
||||
message: Optional[str] = Field(
|
||||
default=None, description="Optional message providing additional context about the authentication result."
|
||||
)
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
"""Middleware that authenticates requests using an external auth endpoint.
|
||||
|
||||
This middleware:
|
||||
1. Extracts the Bearer token from the Authorization header
|
||||
2. Sends it to the configured auth endpoint along with request details
|
||||
3. Validates the response and extracts user attributes
|
||||
4. Makes these attributes available to the route handlers for access control
|
||||
|
||||
Authentication Request Format:
|
||||
```json
|
||||
{
|
||||
"api_key": "the-api-key-extracted-from-auth-header",
|
||||
"request": {
|
||||
"path": "/models/list",
|
||||
"headers": {
|
||||
"content-type": "application/json",
|
||||
"user-agent": "..."
|
||||
// All headers except Authorization
|
||||
},
|
||||
"params": {
|
||||
"limit": ["100"],
|
||||
"offset": ["0"]
|
||||
// Query parameters as key -> list of values
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Expected Auth Endpoint Response Format:
|
||||
```json
|
||||
{
|
||||
"access_attributes": { // Structured attribute format
|
||||
"roles": ["admin", "user"],
|
||||
"teams": ["ml-team", "nlp-team"],
|
||||
"projects": ["llama-3", "project-x"],
|
||||
"namespaces": ["research"]
|
||||
},
|
||||
"message": "Optional message about auth result"
|
||||
}
|
||||
```
|
||||
|
||||
Attribute-Based Access Control:
|
||||
The attributes returned by the auth endpoint are used to determine which
|
||||
resources the user can access. Resources can specify required attributes
|
||||
using the access_attributes field. For a user to access a resource:
|
||||
|
||||
1. All attribute categories specified in the resource must be present in the user's attributes
|
||||
2. For each category, the user must have at least one matching value
|
||||
|
||||
If the auth endpoint doesn't return any attributes, the user will only be able to
|
||||
access resources that don't have access_attributes defined.
|
||||
"""
|
||||
|
||||
def __init__(self, app, auth_endpoint):
|
||||
self.app = app
|
||||
self.auth_endpoint = auth_endpoint
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
headers = dict(scope.get("headers", []))
|
||||
auth_header = headers.get(b"authorization", b"").decode()
|
||||
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
return await self._send_auth_error(send, "Missing or invalid Authorization header")
|
||||
|
||||
api_key = auth_header.split("Bearer ", 1)[1]
|
||||
|
||||
path = scope.get("path", "")
|
||||
request_headers = {k.decode(): v.decode() for k, v in headers.items()}
|
||||
|
||||
# Remove sensitive headers
|
||||
if "authorization" in request_headers:
|
||||
del request_headers["authorization"]
|
||||
|
||||
query_string = scope.get("query_string", b"").decode()
|
||||
params = parse_qs(query_string)
|
||||
|
||||
# Build the auth request model
|
||||
auth_request = AuthRequest(
|
||||
api_key=api_key,
|
||||
request=AuthRequestContext(
|
||||
path=path,
|
||||
headers=request_headers,
|
||||
params=params,
|
||||
),
|
||||
)
|
||||
|
||||
# Validate with authentication endpoint
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.auth_endpoint,
|
||||
json=auth_request.model_dump(),
|
||||
timeout=10.0, # Add a reasonable timeout
|
||||
)
|
||||
if response.status_code != 200:
|
||||
logger.warning(f"Authentication failed: {response.status_code}")
|
||||
return await self._send_auth_error(send, "Authentication failed")
|
||||
|
||||
# Parse and validate the auth response
|
||||
try:
|
||||
response_data = response.json()
|
||||
auth_response = AuthResponse(**response_data)
|
||||
|
||||
# Store attributes in request scope for access control
|
||||
if auth_response.access_attributes:
|
||||
user_attributes = auth_response.access_attributes.model_dump(exclude_none=True)
|
||||
else:
|
||||
logger.warning("No access attributes, setting namespace to api_key by default")
|
||||
user_attributes = {
|
||||
"namespaces": [api_key],
|
||||
}
|
||||
|
||||
scope["user_attributes"] = user_attributes
|
||||
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
|
||||
except Exception:
|
||||
logger.exception("Error parsing authentication response")
|
||||
return await self._send_auth_error(send, "Invalid authentication response format")
|
||||
except httpx.TimeoutException:
|
||||
logger.exception("Authentication request timed out")
|
||||
return await self._send_auth_error(send, "Authentication service timeout")
|
||||
except Exception:
|
||||
logger.exception("Error during authentication")
|
||||
return await self._send_auth_error(send, "Authentication service error")
|
||||
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
async def _send_auth_error(self, send, message):
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 401,
|
||||
"headers": [[b"content-type", b"application/json"]],
|
||||
}
|
||||
)
|
||||
error_msg = json.dumps({"error": {"message": message}}).encode()
|
||||
await send({"type": "http.response.body", "body": error_msg})
|
||||
|
|
@ -6,11 +6,9 @@
|
|||
|
||||
import argparse
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
|
|
@ -27,10 +25,10 @@ from fastapi.responses import JSONResponse, StreamingResponse
|
|||
from pydantic import BaseModel, ValidationError
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.request_headers import (
|
||||
preserve_headers_context_async_generator,
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
|
|
@ -40,6 +38,7 @@ from llama_stack.distribution.stack import (
|
|||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
|
||||
|
|
@ -47,11 +46,13 @@ from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
|
|||
TelemetryAdapter,
|
||||
)
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
CURRENT_TRACE_CONTEXT,
|
||||
end_trace,
|
||||
setup_logger,
|
||||
start_trace,
|
||||
)
|
||||
|
||||
from .auth import AuthenticationMiddleware
|
||||
from .endpoints import get_all_api_endpoints
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
|
@ -118,69 +119,24 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
|
|||
)
|
||||
|
||||
|
||||
def handle_signal(app, signum, _) -> None:
|
||||
async def shutdown(app):
|
||||
"""Initiate a graceful shutdown of the application.
|
||||
|
||||
Handled by the lifespan context manager. The shutdown process involves
|
||||
shutting down all implementations registered in the application.
|
||||
"""
|
||||
Handle incoming signals and initiate a graceful shutdown of the application.
|
||||
|
||||
This function is intended to be used as a signal handler for various signals
|
||||
(e.g., SIGINT, SIGTERM). Upon receiving a signal, it will print a message
|
||||
indicating the received signal and initiate a shutdown process.
|
||||
|
||||
Args:
|
||||
app: The application instance containing implementations to be shut down.
|
||||
signum (int): The signal number received.
|
||||
frame: The current stack frame (not used in this function).
|
||||
|
||||
The shutdown process involves:
|
||||
- Shutting down all implementations registered in the application.
|
||||
- Gathering all running asyncio tasks.
|
||||
- Cancelling all gathered tasks.
|
||||
- Waiting for all tasks to finish.
|
||||
- Stopping the event loop.
|
||||
|
||||
Note:
|
||||
This function schedules the shutdown process as an asyncio task and does
|
||||
not block the current execution.
|
||||
"""
|
||||
signame = signal.Signals(signum).name
|
||||
logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...")
|
||||
|
||||
async def shutdown():
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
impl_name = impl.__class__.__name__
|
||||
logger.info("Shutting down %s", impl_name)
|
||||
try:
|
||||
# Gracefully shut down implementations
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
impl_name = impl.__class__.__name__
|
||||
logger.info("Shutting down %s", impl_name)
|
||||
try:
|
||||
if hasattr(impl, "shutdown"):
|
||||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||
else:
|
||||
logger.warning("No shutdown method for %s", impl_name)
|
||||
except asyncio.TimeoutError:
|
||||
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
|
||||
|
||||
# Gather all running tasks
|
||||
loop = asyncio.get_running_loop()
|
||||
tasks = [task for task in asyncio.all_tasks(loop) if task is not asyncio.current_task()]
|
||||
|
||||
# Cancel all tasks
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
# Wait for all tasks to finish
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
|
||||
except asyncio.TimeoutError:
|
||||
logger.exception("Timeout while waiting for tasks to finish")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
loop.stop()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(shutdown())
|
||||
if hasattr(impl, "shutdown"):
|
||||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||
else:
|
||||
logger.warning("No shutdown method for %s", impl_name)
|
||||
except asyncio.TimeoutError:
|
||||
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
|
||||
except (Exception, asyncio.CancelledError) as e:
|
||||
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
|
@ -188,8 +144,7 @@ async def lifespan(app: FastAPI):
|
|||
logger.info("Starting up")
|
||||
yield
|
||||
logger.info("Shutting down")
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
await impl.shutdown()
|
||||
await shutdown(app)
|
||||
|
||||
|
||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||
|
|
@ -224,19 +179,24 @@ async def sse_generator(event_gen):
|
|||
|
||||
def create_dynamic_typed_route(func: Any, method: str, route: str):
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
# Use context manager for request provider data
|
||||
with request_provider_data_context(request.headers):
|
||||
# Get auth attributes from the request scope
|
||||
user_attributes = request.scope.get("user_attributes", {})
|
||||
|
||||
# Use context manager with both provider data and auth attributes
|
||||
with request_provider_data_context(request.headers, user_attributes):
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
|
||||
try:
|
||||
if is_streaming:
|
||||
gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs)))
|
||||
gen = preserve_contexts_async_generator(
|
||||
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
|
||||
)
|
||||
return StreamingResponse(gen, media_type="text/event-stream")
|
||||
else:
|
||||
value = func(**kwargs)
|
||||
return await maybe_await(value)
|
||||
except Exception as e:
|
||||
logger.exception("Error executing endpoint %s", method, route)
|
||||
logger.exception(f"Error executing endpoint {route=} {method=}")
|
||||
raise translate_exception(e) from e
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
|
@ -266,7 +226,7 @@ class TracingMiddleware:
|
|||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
path = scope["path"]
|
||||
path = scope.get("path", "")
|
||||
await start_trace(path, {"__location__": "server"})
|
||||
try:
|
||||
return await self.app(scope, receive, send)
|
||||
|
|
@ -350,34 +310,42 @@ def main():
|
|||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.env:
|
||||
for env_pair in args.env:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
logger.info(f"Setting CLI environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
log_line = ""
|
||||
if args.yaml_config:
|
||||
# if the user provided a config file, use it, even if template was specified
|
||||
config_file = Path(args.yaml_config)
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Config file {config_file} does not exist")
|
||||
logger.info(f"Using config file: {config_file}")
|
||||
log_line = f"Using config file: {config_file}"
|
||||
elif args.template:
|
||||
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Template {args.template} does not exist")
|
||||
logger.info(f"Using template {args.template} config file: {config_file}")
|
||||
log_line = f"Using template {args.template} config file: {config_file}"
|
||||
else:
|
||||
raise ValueError("Either --yaml-config or --template must be provided")
|
||||
|
||||
logger_config = None
|
||||
with open(config_file, "r") as fp:
|
||||
config = replace_env_vars(yaml.safe_load(fp))
|
||||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
logger = get_logger(name=__name__, category="server", config=logger_config)
|
||||
if args.env:
|
||||
for env_pair in args.env:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
logger.info(f"Setting CLI environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
config = replace_env_vars(config_contents)
|
||||
config = StackRunConfig(**config)
|
||||
|
||||
# now that the logger is initialized, print the line about which type of config we are using.
|
||||
logger.info(log_line)
|
||||
|
||||
logger.info("Run configuration:")
|
||||
safe_config = redact_sensitive_fields(config.model_dump())
|
||||
logger.info(yaml.dump(safe_config, indent=2))
|
||||
|
|
@ -387,6 +355,11 @@ def main():
|
|||
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
||||
app.add_middleware(ClientVersionMiddleware)
|
||||
|
||||
# Add authentication middleware if configured
|
||||
if config.server.auth and config.server.auth.endpoint:
|
||||
logger.info(f"Enabling authentication with endpoint: {config.server.auth.endpoint}")
|
||||
app.add_middleware(AuthenticationMiddleware, auth_endpoint=config.server.auth.endpoint)
|
||||
|
||||
try:
|
||||
impls = asyncio.run(construct_stack(config))
|
||||
except InvalidProviderError as e:
|
||||
|
|
@ -396,7 +369,7 @@ def main():
|
|||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
else:
|
||||
setup_logger(TelemetryAdapter(TelemetryConfig()))
|
||||
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
|
||||
|
||||
all_endpoints = get_all_api_endpoints()
|
||||
|
||||
|
|
@ -412,6 +385,7 @@ def main():
|
|||
apis_to_serve.add(inf.routing_table_api.value)
|
||||
|
||||
apis_to_serve.add("inspect")
|
||||
apis_to_serve.add("providers")
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
||||
|
|
@ -439,8 +413,6 @@ def main():
|
|||
|
||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
|
||||
signal.signal(signal.SIGTERM, functools.partial(handle_signal, app))
|
||||
|
||||
app.__llama_stack_impls__ = impls
|
||||
|
||||
|
|
@ -471,6 +443,8 @@ def main():
|
|||
"app": app,
|
||||
"host": listen_host,
|
||||
"port": port,
|
||||
"lifespan": "on",
|
||||
"log_level": logger.getEffectiveLevel(),
|
||||
}
|
||||
if ssl_config:
|
||||
uvicorn_config.update(ssl_config)
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from llama_stack.apis.inference import Inference
|
|||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.post_training import PostTraining
|
||||
from llama_stack.apis.providers import Providers
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||
|
|
@ -44,6 +45,7 @@ logger = get_logger(name=__name__, category="core")
|
|||
|
||||
|
||||
class LlamaStack(
|
||||
Providers,
|
||||
VectorDBs,
|
||||
Inference,
|
||||
BatchInference,
|
||||
|
|
|
|||
11
llama_stack/distribution/ui/Containerfile
Normal file
11
llama_stack/distribution/ui/Containerfile
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
# More info on playground configuration can be found here:
|
||||
# https://llama-stack.readthedocs.io/en/latest/playground
|
||||
|
||||
FROM python:3.9-slim
|
||||
WORKDIR /app
|
||||
COPY . /app/
|
||||
RUN /usr/local/bin/python -m pip install --upgrade pip && \
|
||||
/usr/local/bin/pip3 install -r requirements.txt
|
||||
EXPOSE 8501
|
||||
|
||||
ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
||||
|
|
@ -40,3 +40,13 @@ cd llama_stack/distribution/ui
|
|||
pip install -r requirements.txt
|
||||
streamlit run app.py
|
||||
```
|
||||
|
||||
## Environment Variables
|
||||
|
||||
| Environment Variable | Description | Default Value |
|
||||
|----------------------------|------------------------------------|---------------------------|
|
||||
| LLAMA_STACK_ENDPOINT | The endpoint for the Llama Stack | http://localhost:8321 |
|
||||
| FIREWORKS_API_KEY | API key for Fireworks provider | (empty string) |
|
||||
| TOGETHER_API_KEY | API key for Together provider | (empty string) |
|
||||
| SAMBANOVA_API_KEY | API key for SambaNova provider | (empty string) |
|
||||
| OPENAI_API_KEY | API key for OpenAI provider | (empty string) |
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def datasets():
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def benchmarks():
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def models():
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def providers():
|
||||
|
|
|
|||
|
|
@ -4,14 +4,15 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from page.distribution.benchmarks import benchmarks
|
||||
from page.distribution.datasets import datasets
|
||||
from page.distribution.models import models
|
||||
from page.distribution.scoring_functions import scoring_functions
|
||||
from page.distribution.shields import shields
|
||||
from page.distribution.vector_dbs import vector_dbs
|
||||
from streamlit_option_menu import option_menu
|
||||
|
||||
from llama_stack.distribution.ui.page.distribution.datasets import datasets
|
||||
from llama_stack.distribution.ui.page.distribution.eval_tasks import benchmarks
|
||||
from llama_stack.distribution.ui.page.distribution.models import models
|
||||
from llama_stack.distribution.ui.page.distribution.scoring_functions import scoring_functions
|
||||
from llama_stack.distribution.ui.page.distribution.shields import shields
|
||||
from llama_stack.distribution.ui.page.distribution.vector_dbs import vector_dbs
|
||||
|
||||
|
||||
def resources_page():
|
||||
options = [
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def scoring_functions():
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def shields():
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def vector_dbs():
|
||||
|
|
|
|||
|
|
@ -8,8 +8,9 @@ import json
|
|||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
from modules.utils import process_dataset
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
from llama_stack.distribution.ui.modules.utils import process_dataset
|
||||
|
||||
|
||||
def application_evaluation_page():
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@ import json
|
|||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def select_benchmark_1():
|
||||
|
|
@ -166,11 +167,10 @@ def run_evaluation_3():
|
|||
eval_candidate = st.session_state["eval_candidate"]
|
||||
|
||||
dataset_id = benchmarks[selected_benchmark].dataset_id
|
||||
rows = llama_stack_api.client.datasetio.get_rows_paginated(
|
||||
rows = llama_stack_api.client.datasets.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
)
|
||||
total_rows = len(rows.rows)
|
||||
total_rows = len(rows.data)
|
||||
# Add number of examples control
|
||||
num_rows = st.number_input(
|
||||
"Number of Examples to Evaluate",
|
||||
|
|
@ -195,7 +195,7 @@ def run_evaluation_3():
|
|||
if st.button("Run Evaluation"):
|
||||
progress_text = "Running evaluation..."
|
||||
progress_bar = st.progress(0, text=progress_text)
|
||||
rows = rows.rows
|
||||
rows = rows.data
|
||||
if num_rows < total_rows:
|
||||
rows = rows[:num_rows]
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
# Sidebar configurations
|
||||
with st.sidebar:
|
||||
|
|
|
|||
|
|
@ -7,9 +7,10 @@
|
|||
import streamlit as st
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client.types.memory_insert_params import Document
|
||||
from modules.api import llama_stack_api
|
||||
from modules.utils import data_url_from_file
|
||||
from llama_stack_client.types.shared.document import Document
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
from llama_stack.distribution.ui.modules.utils import data_url_from_file
|
||||
|
||||
|
||||
def rag_chat_page():
|
||||
|
|
|
|||
33
llama_stack/distribution/utils/context.py
Normal file
33
llama_stack/distribution/utils/context.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from contextvars import ContextVar
|
||||
from typing import AsyncGenerator, List, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def preserve_contexts_async_generator(
|
||||
gen: AsyncGenerator[T, None], context_vars: List[ContextVar]
|
||||
) -> AsyncGenerator[T, None]:
|
||||
"""
|
||||
Wraps an async generator to preserve context variables across iterations.
|
||||
This is needed because we start a new asyncio event loop for each streaming request,
|
||||
and we need to preserve the context across the event loop boundary.
|
||||
"""
|
||||
|
||||
async def wrapper() -> AsyncGenerator[T, None]:
|
||||
while True:
|
||||
try:
|
||||
item = await gen.__anext__()
|
||||
context_values = {context_var.name: context_var.get() for context_var in context_vars}
|
||||
yield item
|
||||
for context_var in context_vars:
|
||||
_ = context_var.set(context_values[context_var.name])
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
return wrapper()
|
||||
|
|
@ -4,13 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import errno
|
||||
import logging
|
||||
import os
|
||||
import select
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
|
|
@ -20,14 +17,14 @@ import importlib
|
|||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.distribution.utils.image_types import ImageType
|
||||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
|
||||
|
||||
def formulate_run_args(image_type, image_name, config, template_name) -> list:
|
||||
env_name = ""
|
||||
if image_type == ImageType.container.value or config.container_image:
|
||||
if image_type == LlamaStackImageType.CONTAINER.value or config.container_image:
|
||||
env_name = f"distribution-{template_name}" if template_name else config.container_image
|
||||
elif image_type == ImageType.conda.value:
|
||||
elif image_type == LlamaStackImageType.CONDA.value:
|
||||
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
|
||||
env_name = image_name or current_conda_env
|
||||
if not env_name:
|
||||
|
|
@ -88,13 +85,6 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
|
|||
return run_args
|
||||
|
||||
|
||||
def run_with_pty(command):
|
||||
if sys.platform.startswith("win"):
|
||||
return _run_with_pty_win(command)
|
||||
else:
|
||||
return _run_with_pty_unix(command)
|
||||
|
||||
|
||||
def in_notebook():
|
||||
try:
|
||||
from IPython import get_ipython
|
||||
|
|
@ -108,19 +98,19 @@ def in_notebook():
|
|||
return True
|
||||
|
||||
|
||||
# run a command in a pseudo-terminal, with interrupt handling,
|
||||
# useful when you want to run interactive things
|
||||
def _run_with_pty_unix(command):
|
||||
import pty
|
||||
import termios
|
||||
def run_command(command: list[str]) -> int:
|
||||
"""
|
||||
Run a command with interrupt handling and output capture.
|
||||
Uses subprocess.run with direct stream piping for better performance.
|
||||
|
||||
master, slave = pty.openpty()
|
||||
Args:
|
||||
command (list): The command to run.
|
||||
|
||||
old_settings = termios.tcgetattr(sys.stdin)
|
||||
Returns:
|
||||
int: The return code of the command.
|
||||
"""
|
||||
original_sigint = signal.getsignal(signal.SIGINT)
|
||||
|
||||
ctrl_c_pressed = False
|
||||
process = None
|
||||
|
||||
def sigint_handler(signum, frame):
|
||||
nonlocal ctrl_c_pressed
|
||||
|
|
@ -131,106 +121,19 @@ def _run_with_pty_unix(command):
|
|||
# Set up the signal handler
|
||||
signal.signal(signal.SIGINT, sigint_handler)
|
||||
|
||||
new_settings = termios.tcgetattr(sys.stdin)
|
||||
new_settings[3] = new_settings[3] & ~termios.ECHO # Disable echo
|
||||
new_settings[3] = new_settings[3] & ~termios.ICANON # Disable canonical mode
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, new_settings)
|
||||
|
||||
process = subprocess.Popen(
|
||||
# Run the command with stdout/stderr piped directly to system streams
|
||||
result = subprocess.run(
|
||||
command,
|
||||
stdin=slave,
|
||||
stdout=slave,
|
||||
stderr=slave,
|
||||
universal_newlines=True,
|
||||
preexec_fn=os.setsid,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
|
||||
# Close the slave file descriptor as it's now owned by the subprocess
|
||||
os.close(slave)
|
||||
|
||||
def handle_io():
|
||||
while not ctrl_c_pressed:
|
||||
try:
|
||||
rlist, _, _ = select.select([sys.stdin, master], [], [], 0.1)
|
||||
|
||||
if sys.stdin in rlist:
|
||||
data = os.read(sys.stdin.fileno(), 1024)
|
||||
if not data:
|
||||
break
|
||||
os.write(master, data)
|
||||
|
||||
if master in rlist:
|
||||
data = os.read(master, 1024)
|
||||
if not data:
|
||||
break
|
||||
sys.stdout.buffer.write(data)
|
||||
sys.stdout.flush()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
# This will be raised when Ctrl+C is pressed
|
||||
break
|
||||
|
||||
if process.poll() is not None:
|
||||
break
|
||||
|
||||
handle_io()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
pass
|
||||
except OSError as e:
|
||||
if e.errno != errno.EIO:
|
||||
raise
|
||||
finally:
|
||||
# Clean up
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
|
||||
signal.signal(signal.SIGINT, original_sigint)
|
||||
|
||||
os.close(master)
|
||||
if process and process.poll() is None:
|
||||
process.terminate()
|
||||
process.wait()
|
||||
|
||||
return process.returncode
|
||||
|
||||
|
||||
# run a command in a pseudo-terminal in windows, with interrupt handling,
|
||||
def _run_with_pty_win(command):
|
||||
"""
|
||||
Runs a command with interactive support using subprocess directly.
|
||||
"""
|
||||
try:
|
||||
# For shell scripts on Windows, use appropriate shell
|
||||
if isinstance(command, (list, tuple)):
|
||||
if command[0].endswith(".sh"):
|
||||
if os.path.exists("/usr/bin/bash"): # WSL
|
||||
command = ["bash"] + command
|
||||
else:
|
||||
# Use cmd.exe with bash while preserving all arguments
|
||||
command = ["cmd.exe", "/c", "bash"] + command
|
||||
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
shell=True,
|
||||
universal_newlines=True,
|
||||
)
|
||||
|
||||
process.wait()
|
||||
|
||||
return result.returncode
|
||||
except subprocess.SubprocessError as e:
|
||||
log.error(f"Subprocess error: {e}")
|
||||
return 1
|
||||
except Exception as e:
|
||||
print(f"Error: {str(e)}")
|
||||
log.exception(f"Unexpected error: {e}")
|
||||
return 1
|
||||
finally:
|
||||
if process and process.poll() is None:
|
||||
process.terminate()
|
||||
process.wait()
|
||||
return process.returncode
|
||||
|
||||
|
||||
def run_command(command):
|
||||
try:
|
||||
result = subprocess.run(command, capture_output=True, text=True, check=True)
|
||||
print("Script Output\n", result.stdout)
|
||||
return result.returncode
|
||||
except subprocess.CalledProcessError as e:
|
||||
print("Error running script:", e)
|
||||
print("Error output:", e.stderr)
|
||||
return e.returncode
|
||||
# Restore the original signal handler
|
||||
signal.signal(signal.SIGINT, original_sigint)
|
||||
|
|
|
|||
|
|
@ -4,10 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
import enum
|
||||
|
||||
|
||||
class ImageType(Enum):
|
||||
container = "container"
|
||||
conda = "conda"
|
||||
venv = "venv"
|
||||
class LlamaStackImageType(enum.Enum):
|
||||
CONTAINER = "container"
|
||||
CONDA = "conda"
|
||||
VENV = "venv"
|
||||
|
|
|
|||
155
llama_stack/distribution/utils/tests/test_context.py
Normal file
155
llama_stack/distribution/utils/tests/test_context.py
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextvars import ContextVar
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserve_contexts_with_exception():
|
||||
# Create context variable
|
||||
context_var = ContextVar("exception_var", default="initial")
|
||||
token = context_var.set("start_value")
|
||||
|
||||
# Create an async generator that raises an exception
|
||||
async def exception_generator():
|
||||
yield context_var.get()
|
||||
context_var.set("modified")
|
||||
raise ValueError("Test exception")
|
||||
yield None # This will never be reached
|
||||
|
||||
# Wrap the generator
|
||||
wrapped_gen = preserve_contexts_async_generator(exception_generator(), [context_var])
|
||||
|
||||
# First iteration should work
|
||||
value = await wrapped_gen.__anext__()
|
||||
assert value == "start_value"
|
||||
|
||||
# Second iteration should raise the exception
|
||||
with pytest.raises(ValueError, match="Test exception"):
|
||||
await wrapped_gen.__anext__()
|
||||
|
||||
# Clean up
|
||||
context_var.reset(token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserve_contexts_empty_generator():
|
||||
# Create context variable
|
||||
context_var = ContextVar("empty_var", default="initial")
|
||||
token = context_var.set("value")
|
||||
|
||||
# Create an empty async generator
|
||||
async def empty_generator():
|
||||
if False: # This condition ensures the generator yields nothing
|
||||
yield None
|
||||
|
||||
# Wrap the generator
|
||||
wrapped_gen = preserve_contexts_async_generator(empty_generator(), [context_var])
|
||||
|
||||
# The generator should raise StopAsyncIteration immediately
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await wrapped_gen.__anext__()
|
||||
|
||||
# Context variable should remain unchanged
|
||||
assert context_var.get() == "value"
|
||||
|
||||
# Clean up
|
||||
context_var.reset(token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserve_contexts_across_event_loops():
|
||||
"""
|
||||
Test that context variables are preserved across event loop boundaries with nested generators.
|
||||
This simulates the real-world scenario where:
|
||||
1. A new event loop is created for each streaming request
|
||||
2. The async generator runs inside that loop
|
||||
3. There are multiple levels of nested generators
|
||||
4. Context needs to be preserved across these boundaries
|
||||
"""
|
||||
# Create context variables
|
||||
request_id = ContextVar("request_id", default=None)
|
||||
user_id = ContextVar("user_id", default=None)
|
||||
|
||||
# Set initial values
|
||||
|
||||
# Results container to verify values across thread boundaries
|
||||
results = []
|
||||
|
||||
# Inner-most generator (level 2)
|
||||
async def inner_generator():
|
||||
# Should have the context from the outer scope
|
||||
yield (1, request_id.get(), user_id.get())
|
||||
|
||||
# Modify one context variable
|
||||
user_id.set("user-modified")
|
||||
|
||||
# Should reflect the modification
|
||||
yield (2, request_id.get(), user_id.get())
|
||||
|
||||
# Middle generator (level 1)
|
||||
async def middle_generator():
|
||||
inner_gen = inner_generator()
|
||||
|
||||
# Forward the first yield from inner
|
||||
item = await inner_gen.__anext__()
|
||||
yield item
|
||||
|
||||
# Forward the second yield from inner
|
||||
item = await inner_gen.__anext__()
|
||||
yield item
|
||||
|
||||
request_id.set("req-modified")
|
||||
|
||||
# Add our own yield with both modified variables
|
||||
yield (3, request_id.get(), user_id.get())
|
||||
|
||||
# Function to run in a separate thread with a new event loop
|
||||
def run_in_new_loop():
|
||||
# Create a new event loop for this thread
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# Outer generator (runs in the new loop)
|
||||
async def outer_generator():
|
||||
request_id.set("req-12345")
|
||||
user_id.set("user-6789")
|
||||
# Wrap the middle generator
|
||||
wrapped_gen = preserve_contexts_async_generator(middle_generator(), [request_id, user_id])
|
||||
|
||||
# Process all items from the middle generator
|
||||
async for item in wrapped_gen:
|
||||
# Store results for verification
|
||||
results.append(item)
|
||||
|
||||
# Run the outer generator in the new loop
|
||||
loop.run_until_complete(outer_generator())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Run the generator chain in a separate thread with a new event loop
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(run_in_new_loop)
|
||||
future.result() # Wait for completion
|
||||
|
||||
# Verify the results
|
||||
assert len(results) == 3
|
||||
|
||||
# First yield should have original values
|
||||
assert results[0] == (1, "req-12345", "user-6789")
|
||||
|
||||
# Second yield should have modified user_id
|
||||
assert results[1] == (2, "req-12345", "user-modified")
|
||||
|
||||
# Third yield should have both modified values
|
||||
assert results[2] == (3, "req-modified", "user-modified")
|
||||
|
|
@ -7,11 +7,14 @@
|
|||
import logging
|
||||
import os
|
||||
from logging.config import dictConfig
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.errors import MarkupError
|
||||
from rich.logging import RichHandler
|
||||
from termcolor import cprint
|
||||
|
||||
from .distribution.datatypes import LoggingConfig
|
||||
|
||||
# Default log level
|
||||
DEFAULT_LOG_LEVEL = logging.INFO
|
||||
|
|
@ -33,6 +36,56 @@ CATEGORIES = [
|
|||
_category_levels: Dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES}
|
||||
|
||||
|
||||
def config_to_category_levels(category: str, level: str):
|
||||
"""
|
||||
Helper function to be called either by environment parsing or yaml parsing to go from a list of categories and levels to a dictionary ready to be
|
||||
used by the logger dictConfig.
|
||||
|
||||
Parameters:
|
||||
category (str): logging category to apply the level to
|
||||
level (str): logging level to be used in the category
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: A dictionary mapping categories to their log levels.
|
||||
"""
|
||||
|
||||
category_levels: Dict[str, int] = {}
|
||||
level_value = logging._nameToLevel.get(str(level).upper())
|
||||
if level_value is None:
|
||||
logging.warning(f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'.")
|
||||
return category_levels
|
||||
|
||||
if category == "all":
|
||||
# Apply the log level to all categories and the root logger
|
||||
for cat in CATEGORIES:
|
||||
category_levels[cat] = level_value
|
||||
# Set the root logger's level to the specified level
|
||||
category_levels["root"] = level_value
|
||||
elif category in CATEGORIES:
|
||||
category_levels[category] = level_value
|
||||
logging.info(f"Setting '{category}' category to level '{level}'.")
|
||||
else:
|
||||
logging.warning(f"Unknown logging category: {category}. No changes made.")
|
||||
return category_levels
|
||||
|
||||
|
||||
def parse_yaml_config(yaml_config: LoggingConfig) -> Dict[str, int]:
|
||||
"""
|
||||
Helper function to parse a yaml logging configuration found in the run.yaml
|
||||
|
||||
Parameters:
|
||||
yaml_config (Logging): the logger config object found in the run.yaml
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: A dictionary mapping categories to their log levels.
|
||||
"""
|
||||
category_levels = {}
|
||||
for category, level in yaml_config.category_levels.items():
|
||||
category_levels.update(config_to_category_levels(category=category, level=level))
|
||||
|
||||
return category_levels
|
||||
|
||||
|
||||
def parse_environment_config(env_config: str) -> Dict[str, int]:
|
||||
"""
|
||||
Parse the LLAMA_STACK_LOGGING environment variable and return a dictionary of category log levels.
|
||||
|
|
@ -52,25 +105,7 @@ def parse_environment_config(env_config: str) -> Dict[str, int]:
|
|||
category, level = pair.split("=", 1)
|
||||
category = category.strip().lower()
|
||||
level = level.strip().upper() # Convert to uppercase for logging._nameToLevel
|
||||
|
||||
level_value = logging._nameToLevel.get(level)
|
||||
if level_value is None:
|
||||
logging.warning(
|
||||
f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'."
|
||||
)
|
||||
continue
|
||||
|
||||
if category == "all":
|
||||
# Apply the log level to all categories and the root logger
|
||||
for cat in CATEGORIES:
|
||||
category_levels[cat] = level_value
|
||||
# Set the root logger's level to the specified level
|
||||
category_levels["root"] = level_value
|
||||
elif category in CATEGORIES:
|
||||
category_levels[category] = level_value
|
||||
logging.info(f"Setting '{category}' category to level '{level}'.")
|
||||
else:
|
||||
logging.warning(f"Unknown logging category: {category}. No changes made.")
|
||||
category_levels.update(config_to_category_levels(category=category, level=level))
|
||||
|
||||
except ValueError:
|
||||
logging.warning(f"Invalid logging configuration: '{pair}'. Expected format: 'category=level'.")
|
||||
|
|
@ -96,12 +131,13 @@ class CustomRichHandler(RichHandler):
|
|||
self.markup = original_markup
|
||||
|
||||
|
||||
def setup_logging(category_levels: Dict[str, int]) -> None:
|
||||
def setup_logging(category_levels: Dict[str, int], log_file: str | None) -> None:
|
||||
"""
|
||||
Configure logging based on the provided category log levels.
|
||||
Configure logging based on the provided category log levels and an optional log file.
|
||||
|
||||
Parameters:
|
||||
category_levels (Dict[str, int]): A dictionary mapping categories to their log levels.
|
||||
log_file (str): Path to a log file to additionally pipe the logs into
|
||||
"""
|
||||
log_format = "[dim]%(asctime)s %(name)s:%(lineno)d[/] [yellow dim]%(category)s[/]: %(message)s"
|
||||
|
||||
|
|
@ -116,6 +152,28 @@ def setup_logging(category_levels: Dict[str, int]) -> None:
|
|||
# Determine the root logger's level (default to WARNING if not specified)
|
||||
root_level = category_levels.get("root", logging.WARNING)
|
||||
|
||||
handlers = {
|
||||
"console": {
|
||||
"()": CustomRichHandler, # Use custom console handler
|
||||
"formatter": "rich",
|
||||
"rich_tracebacks": True,
|
||||
"show_time": False,
|
||||
"show_path": False,
|
||||
"markup": True,
|
||||
"filters": ["category_filter"],
|
||||
}
|
||||
}
|
||||
|
||||
# Add a file handler if log_file is set
|
||||
if log_file:
|
||||
handlers["file"] = {
|
||||
"class": "logging.FileHandler",
|
||||
"formatter": "rich",
|
||||
"filename": log_file,
|
||||
"mode": "a",
|
||||
"encoding": "utf-8",
|
||||
}
|
||||
|
||||
logging_config = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
|
|
@ -125,17 +183,7 @@ def setup_logging(category_levels: Dict[str, int]) -> None:
|
|||
"format": log_format,
|
||||
}
|
||||
},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"()": CustomRichHandler, # Use our custom handler class
|
||||
"formatter": "rich",
|
||||
"rich_tracebacks": True,
|
||||
"show_time": False,
|
||||
"show_path": False,
|
||||
"markup": True,
|
||||
"filters": ["category_filter"],
|
||||
}
|
||||
},
|
||||
"handlers": handlers,
|
||||
"filters": {
|
||||
"category_filter": {
|
||||
"()": CategoryFilter,
|
||||
|
|
@ -143,21 +191,28 @@ def setup_logging(category_levels: Dict[str, int]) -> None:
|
|||
},
|
||||
"loggers": {
|
||||
category: {
|
||||
"handlers": ["console"],
|
||||
"handlers": list(handlers.keys()), # Apply all handlers
|
||||
"level": category_levels.get(category, DEFAULT_LOG_LEVEL),
|
||||
"propagate": False, # Disable propagation to root logger
|
||||
}
|
||||
for category in CATEGORIES
|
||||
},
|
||||
"root": {
|
||||
"handlers": ["console"],
|
||||
"handlers": list(handlers.keys()),
|
||||
"level": root_level, # Set root logger's level dynamically
|
||||
},
|
||||
}
|
||||
dictConfig(logging_config)
|
||||
|
||||
# Ensure third-party libraries follow the root log level
|
||||
for _, logger in logging.root.manager.loggerDict.items():
|
||||
if isinstance(logger, logging.Logger):
|
||||
logger.setLevel(root_level)
|
||||
|
||||
def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdapter:
|
||||
|
||||
def get_logger(
|
||||
name: str, category: str = "uncategorized", config: Optional[LoggingConfig] | None = None
|
||||
) -> logging.LoggerAdapter:
|
||||
"""
|
||||
Returns a logger with the specified name and category.
|
||||
If no category is provided, defaults to 'uncategorized'.
|
||||
|
|
@ -165,10 +220,14 @@ def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdap
|
|||
Parameters:
|
||||
name (str): The name of the logger (e.g., module or filename).
|
||||
category (str): The category of the logger (default 'uncategorized').
|
||||
config (Logging): optional yaml config to override the existing logger configuration
|
||||
|
||||
Returns:
|
||||
logging.LoggerAdapter: Configured logger with category support.
|
||||
"""
|
||||
if config:
|
||||
_category_levels.update(parse_yaml_config(config))
|
||||
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(_category_levels.get(category, DEFAULT_LOG_LEVEL))
|
||||
return logging.LoggerAdapter(logger, {"category": category})
|
||||
|
|
@ -176,7 +235,9 @@ def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdap
|
|||
|
||||
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
|
||||
if env_config:
|
||||
print(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}")
|
||||
cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", "yellow")
|
||||
_category_levels.update(parse_environment_config(env_config))
|
||||
|
||||
setup_logging(_category_levels)
|
||||
log_file = os.environ.get("LLAMA_STACK_LOG_FILE")
|
||||
|
||||
setup_logging(_category_levels, log_file)
|
||||
|
|
|
|||
|
|
@ -47,7 +47,14 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
|
|||
class ToolCall(BaseModel):
|
||||
call_id: str
|
||||
tool_name: Union[BuiltinTool, str]
|
||||
arguments: Dict[str, RecursiveType]
|
||||
# Plan is to deprecate the Dict in favor of a JSON string
|
||||
# that is parsed on the client side instead of trying to manage
|
||||
# the recursive type here.
|
||||
# Making this a union so that client side can start prepping for this change.
|
||||
# Eventually, we will remove both the Dict and arguments_json field,
|
||||
# and arguments will just be a str
|
||||
arguments: Union[str, Dict[str, RecursiveType]]
|
||||
arguments_json: Optional[str] = None
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
|
|
@ -179,13 +186,11 @@ class TopKSamplingStrategy(BaseModel):
|
|||
top_k: int = Field(..., ge=1)
|
||||
|
||||
|
||||
SamplingStrategy = register_schema(
|
||||
Annotated[
|
||||
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="SamplingStrategy",
|
||||
)
|
||||
SamplingStrategy = Annotated[
|
||||
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(SamplingStrategy, name="SamplingStrategy")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
# the top-level of this source tree.
|
||||
|
||||
import io
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
|
@ -203,9 +204,10 @@ class ChatFormat:
|
|||
# This code tries to handle that case
|
||||
if tool_name in BuiltinTool.__members__:
|
||||
tool_name = BuiltinTool[tool_name]
|
||||
tool_arguments = {
|
||||
"query": list(tool_arguments.values())[0],
|
||||
}
|
||||
if isinstance(tool_arguments, dict):
|
||||
tool_arguments = {
|
||||
"query": list(tool_arguments.values())[0],
|
||||
}
|
||||
else:
|
||||
builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
|
||||
if builtin_tool_info is not None:
|
||||
|
|
@ -229,6 +231,7 @@ class ChatFormat:
|
|||
call_id=call_id,
|
||||
tool_name=tool_name,
|
||||
arguments=tool_arguments,
|
||||
arguments_json=json.dumps(tool_arguments),
|
||||
)
|
||||
)
|
||||
content = ""
|
||||
|
|
|
|||
|
|
@ -34,7 +34,9 @@ class SystemDefaultGenerator(PromptTemplateGeneratorBase):
|
|||
)
|
||||
return PromptTemplate(
|
||||
template_str.lstrip("\n"),
|
||||
{"today": datetime.now().strftime("%d %B %Y")},
|
||||
{
|
||||
"today": datetime.now().strftime("%d %B %Y") # noqa: DTZ005 - we don't care about timezones here since we are displaying the date
|
||||
},
|
||||
)
|
||||
|
||||
def data_examples(self) -> List[Any]:
|
||||
|
|
|
|||
|
|
@ -11,11 +11,8 @@
|
|||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
)
|
||||
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
|
||||
|
||||
from .prompt_templates import (
|
||||
BuiltinToolGenerator,
|
||||
|
|
|
|||
|
|
@ -15,8 +15,11 @@ import json
|
|||
import re
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
|
||||
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
|
||||
|
||||
|
|
@ -92,7 +95,15 @@ def parse_python_list_for_function_calls(input_string):
|
|||
|
||||
# Extract keyword arguments
|
||||
for keyword in node.keywords:
|
||||
function_args[keyword.arg] = ast.literal_eval(keyword.value)
|
||||
try:
|
||||
function_args[keyword.arg] = ast.literal_eval(keyword.value)
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"Error parsing tool call argument '{keyword.arg}': {e}, full input string: '{input_string}'"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Error parsing tool call argument '{keyword.arg}', full input string: '{input_string}'"
|
||||
) from e
|
||||
|
||||
result.append((function_name, function_args))
|
||||
|
||||
|
|
|
|||
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]):
|
||||
from .agents import MetaReferenceAgentsImpl
|
||||
|
||||
impl = MetaReferenceAgentsImpl(
|
||||
|
|
|
|||
|
|
@ -11,8 +11,8 @@ import re
|
|||
import secrets
|
||||
import string
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
|
@ -153,7 +153,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
messages.append(
|
||||
ToolResponseMessage(
|
||||
call_id=response.call_id,
|
||||
tool_name=response.tool_name,
|
||||
content=response.content,
|
||||
)
|
||||
)
|
||||
|
|
@ -181,7 +180,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
return messages
|
||||
|
||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||
with tracing.span("create_and_execute_turn") as span:
|
||||
await self._initialize_tools(request.toolgroups)
|
||||
async with tracing.span("create_and_execute_turn") as span:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
|
|
@ -191,7 +191,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield chunk
|
||||
|
||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||
with tracing.span("resume_turn") as span:
|
||||
await self._initialize_tools()
|
||||
async with tracing.span("resume_turn") as span:
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("turn_id", request.turn_id)
|
||||
|
|
@ -218,18 +219,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
steps = []
|
||||
messages = await self.get_messages_from_turns(turns)
|
||||
if is_resume:
|
||||
if isinstance(request.tool_responses[0], ToolResponseMessage):
|
||||
tool_response_messages = request.tool_responses
|
||||
tool_responses = [
|
||||
ToolResponse(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
|
||||
for x in request.tool_responses
|
||||
]
|
||||
else:
|
||||
tool_response_messages = [
|
||||
ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
|
||||
for x in request.tool_responses
|
||||
]
|
||||
tool_responses = request.tool_responses
|
||||
tool_response_messages = [
|
||||
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
|
||||
]
|
||||
messages.extend(tool_response_messages)
|
||||
last_turn = turns[-1]
|
||||
last_turn_messages = self.turn_to_messages(last_turn)
|
||||
|
|
@ -247,12 +239,12 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
|
||||
request.session_id, request.turn_id
|
||||
)
|
||||
now = datetime.now().astimezone().isoformat()
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
tool_execution_step = ToolExecutionStep(
|
||||
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
|
||||
turn_id=request.turn_id,
|
||||
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
|
||||
tool_responses=tool_responses,
|
||||
tool_responses=request.tool_responses,
|
||||
completed_at=now,
|
||||
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
|
||||
)
|
||||
|
|
@ -272,7 +264,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
start_time = last_turn.started_at
|
||||
else:
|
||||
messages.extend(request.messages)
|
||||
start_time = datetime.now().astimezone().isoformat()
|
||||
start_time = datetime.now(timezone.utc).isoformat()
|
||||
input_messages = request.messages
|
||||
|
||||
output_message = None
|
||||
|
|
@ -283,7 +275,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
documents=request.documents if not is_resume else None,
|
||||
toolgroups_for_turn=request.toolgroups if not is_resume else None,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
output_message = chunk
|
||||
|
|
@ -304,7 +295,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
input_messages=input_messages,
|
||||
output_message=output_message,
|
||||
started_at=start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
|
|
@ -335,7 +326,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||
) -> AsyncGenerator:
|
||||
# Doing async generators makes downstream code much simpler and everything amenable to
|
||||
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
||||
|
|
@ -358,7 +348,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params,
|
||||
stream,
|
||||
documents,
|
||||
toolgroups_for_turn,
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
|
|
@ -390,14 +379,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
shields: List[str],
|
||||
touchpoint: str,
|
||||
) -> AsyncGenerator:
|
||||
with tracing.span("run_shields") as span:
|
||||
async with tracing.span("run_shields") as span:
|
||||
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
||||
if len(shields) == 0:
|
||||
span.set_attribute("output", "no shields")
|
||||
return
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
shield_call_start_time = datetime.now().astimezone().isoformat()
|
||||
shield_call_start_time = datetime.now(timezone.utc).isoformat()
|
||||
try:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
|
@ -421,7 +410,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
turn_id=turn_id,
|
||||
violation=e.violation,
|
||||
started_at=shield_call_start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
@ -444,7 +433,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
turn_id=turn_id,
|
||||
violation=None,
|
||||
started_at=shield_call_start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
@ -459,30 +448,19 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||
) -> AsyncGenerator:
|
||||
# TODO: simplify all of this code, it can be simpler
|
||||
toolgroup_args = {}
|
||||
toolgroups = set()
|
||||
for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []):
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||
tool_group_name, tool_name = self._parse_toolgroup_name(toolgroup.name)
|
||||
toolgroups.add(tool_group_name)
|
||||
toolgroup_args[tool_group_name] = toolgroup.args
|
||||
else:
|
||||
toolgroups.add(toolgroup)
|
||||
|
||||
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
|
||||
if documents:
|
||||
await self.handle_documents(session_id, documents, input_messages, tool_defs)
|
||||
await self.handle_documents(session_id, documents, input_messages)
|
||||
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
# if the session has a memory bank id, let the memory tool use it
|
||||
if session_info and session_info.vector_db_id:
|
||||
if RAG_TOOL_GROUP not in toolgroup_args:
|
||||
toolgroup_args[RAG_TOOL_GROUP] = {"vector_db_ids": [session_info.vector_db_id]}
|
||||
else:
|
||||
toolgroup_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id)
|
||||
for tool_name in self.tool_name_to_args.keys():
|
||||
if tool_name == MEMORY_QUERY_TOOL:
|
||||
if "vector_db_ids" not in self.tool_name_to_args[tool_name]:
|
||||
self.tool_name_to_args[tool_name]["vector_db_ids"] = [session_info.vector_db_id]
|
||||
else:
|
||||
self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id)
|
||||
|
||||
output_attachments = []
|
||||
|
||||
|
|
@ -494,7 +472,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
client_tools[tool.name] = tool
|
||||
while True:
|
||||
step_id = str(uuid.uuid4())
|
||||
inference_start_time = datetime.now().astimezone().isoformat()
|
||||
inference_start_time = datetime.now(timezone.utc).isoformat()
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
|
|
@ -508,11 +486,11 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
content = ""
|
||||
stop_reason = None
|
||||
|
||||
with tracing.span("inference") as span:
|
||||
async with tracing.span("inference") as span:
|
||||
async for chunk in await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
tools=tool_defs,
|
||||
tools=self.tool_defs,
|
||||
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
|
||||
response_format=self.agent_config.response_format,
|
||||
stream=True,
|
||||
|
|
@ -604,7 +582,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
turn_id=turn_id,
|
||||
model_response=copy.deepcopy(message),
|
||||
started_at=inference_start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
@ -636,125 +614,143 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
|
||||
input_messages = input_messages + [message]
|
||||
else:
|
||||
logger.debug(f"completion message (iter: {n_iter}) from the model: {str(message)}")
|
||||
# 1. Start the tool execution step and progress
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
tool_call = message.tool_calls[0]
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
tool_call=tool_call,
|
||||
delta=ToolCallDelta(
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
tool_call=tool_call,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
input_messages = input_messages + [message]
|
||||
|
||||
# If tool is a client tool, yield CompletionMessage and return
|
||||
if tool_call.tool_name in client_tools:
|
||||
# NOTE: mark end_of_message to indicate to client that it may
|
||||
# call the tool and continue the conversation with the tool's response.
|
||||
message.stop_reason = StopReason.end_of_message
|
||||
# Process tool calls in the message
|
||||
client_tool_calls = []
|
||||
non_client_tool_calls = []
|
||||
|
||||
# Separate client and non-client tool calls
|
||||
for tool_call in message.tool_calls:
|
||||
if tool_call.tool_name in client_tools:
|
||||
client_tool_calls.append(tool_call)
|
||||
else:
|
||||
non_client_tool_calls.append(tool_call)
|
||||
|
||||
# Process non-client tool calls first
|
||||
for tool_call in non_client_tool_calls:
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
delta=ToolCallDelta(
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
tool_call=tool_call,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Execute the tool call
|
||||
async with tracing.span(
|
||||
"tool_execution",
|
||||
{
|
||||
"tool_name": tool_call.tool_name,
|
||||
"input": message.model_dump_json(),
|
||||
},
|
||||
) as span:
|
||||
tool_execution_start_time = datetime.now(timezone.utc).isoformat()
|
||||
tool_result = await self.execute_tool_call_maybe(
|
||||
session_id,
|
||||
tool_call,
|
||||
)
|
||||
if tool_result.content is None:
|
||||
raise ValueError(
|
||||
f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
|
||||
)
|
||||
result_message = ToolResponseMessage(
|
||||
call_id=tool_call.call_id,
|
||||
content=tool_result.content,
|
||||
)
|
||||
span.set_attribute("output", result_message.model_dump_json())
|
||||
|
||||
# Store tool execution step
|
||||
tool_execution_step = ToolExecutionStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
tool_calls=[tool_call],
|
||||
tool_responses=[
|
||||
ToolResponse(
|
||||
call_id=tool_call.call_id,
|
||||
tool_name=tool_call.tool_name,
|
||||
content=tool_result.content,
|
||||
metadata=tool_result.metadata,
|
||||
)
|
||||
],
|
||||
started_at=tool_execution_start_time,
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
|
||||
# Yield the step completion event
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
step_details=tool_execution_step,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Add the result message to input_messages for the next iteration
|
||||
input_messages.append(result_message)
|
||||
|
||||
# TODO: add tool-input touchpoint and a "start" event for this step also
|
||||
# but that needs a lot more refactoring of Tool code potentially
|
||||
if (type(result_message.content) is str) and (
|
||||
out_attachment := _interpret_content_as_attachment(result_message.content)
|
||||
):
|
||||
# NOTE: when we push this message back to the model, the model may ignore the
|
||||
# attached file path etc. since the model is trained to only provide a user message
|
||||
# with the summary. We keep all generated attachments and then attach them to final message
|
||||
output_attachments.append(out_attachment)
|
||||
|
||||
# If there are client tool calls, yield a message with only those tool calls
|
||||
if client_tool_calls:
|
||||
await self.storage.set_in_progress_tool_call_step(
|
||||
session_id,
|
||||
turn_id,
|
||||
ToolExecutionStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
tool_calls=[tool_call],
|
||||
tool_calls=client_tool_calls,
|
||||
tool_responses=[],
|
||||
started_at=datetime.now().astimezone().isoformat(),
|
||||
started_at=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
yield message
|
||||
|
||||
# Create a copy of the message with only client tool calls
|
||||
client_message = message.model_copy(deep=True)
|
||||
client_message.tool_calls = client_tool_calls
|
||||
# NOTE: mark end_of_message to indicate to client that it may
|
||||
# call the tool and continue the conversation with the tool's response.
|
||||
client_message.stop_reason = StopReason.end_of_message
|
||||
|
||||
# Yield the message with client tool calls
|
||||
yield client_message
|
||||
return
|
||||
|
||||
# If tool is a builtin server tool, execute it
|
||||
tool_name = tool_call.tool_name
|
||||
if isinstance(tool_name, BuiltinTool):
|
||||
tool_name = tool_name.value
|
||||
with tracing.span(
|
||||
"tool_execution",
|
||||
{
|
||||
"tool_name": tool_name,
|
||||
"input": message.model_dump_json(),
|
||||
},
|
||||
) as span:
|
||||
tool_execution_start_time = datetime.now().astimezone().isoformat()
|
||||
tool_call = message.tool_calls[0]
|
||||
tool_result = await execute_tool_call_maybe(
|
||||
self.tool_runtime_api,
|
||||
session_id,
|
||||
tool_call,
|
||||
toolgroup_args,
|
||||
tool_to_group,
|
||||
)
|
||||
if tool_result.content is None:
|
||||
raise ValueError(
|
||||
f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
|
||||
)
|
||||
result_messages = [
|
||||
ToolResponseMessage(
|
||||
call_id=tool_call.call_id,
|
||||
tool_name=tool_call.tool_name,
|
||||
content=tool_result.content,
|
||||
)
|
||||
]
|
||||
assert len(result_messages) == 1, "Currently not supporting multiple messages"
|
||||
result_message = result_messages[0]
|
||||
span.set_attribute("output", result_message.model_dump_json())
|
||||
async def _initialize_tools(
|
||||
self,
|
||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||
) -> None:
|
||||
toolgroup_to_args = {}
|
||||
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||
tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
|
||||
toolgroup_to_args[tool_group_name] = toolgroup.args
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
step_details=ToolExecutionStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
tool_calls=[tool_call],
|
||||
tool_responses=[
|
||||
ToolResponse(
|
||||
call_id=result_message.call_id,
|
||||
tool_name=result_message.tool_name,
|
||||
content=result_message.content,
|
||||
metadata=tool_result.metadata,
|
||||
)
|
||||
],
|
||||
started_at=tool_execution_start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: add tool-input touchpoint and a "start" event for this step also
|
||||
# but that needs a lot more refactoring of Tool code potentially
|
||||
if (type(result_message.content) is str) and (
|
||||
out_attachment := _interpret_content_as_attachment(result_message.content)
|
||||
):
|
||||
# NOTE: when we push this message back to the model, the model may ignore the
|
||||
# attached file path etc. since the model is trained to only provide a user message
|
||||
# with the summary. We keep all generated attachments and then attach them to final message
|
||||
output_attachments.append(out_attachment)
|
||||
|
||||
input_messages = input_messages + [message, result_message]
|
||||
|
||||
async def _get_tool_defs(
|
||||
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
||||
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
|
||||
# Determine which tools to include
|
||||
tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or []
|
||||
agent_config_toolgroups = []
|
||||
|
|
@ -763,8 +759,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
if name not in agent_config_toolgroups:
|
||||
agent_config_toolgroups.append(name)
|
||||
|
||||
toolgroup_to_args = toolgroup_to_args or {}
|
||||
|
||||
tool_name_to_def = {}
|
||||
tool_to_group = {}
|
||||
tool_name_to_args = {}
|
||||
|
||||
for tool_def in self.agent_config.client_tools:
|
||||
if tool_name_to_def.get(tool_def.name, None):
|
||||
|
|
@ -782,53 +780,38 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[tool_def.name] = "__client_tools__"
|
||||
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
||||
toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
||||
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
||||
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
|
||||
if not tools.data:
|
||||
available_tool_groups = ", ".join(
|
||||
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
|
||||
)
|
||||
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
|
||||
if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data):
|
||||
if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data):
|
||||
raise ValueError(
|
||||
f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
|
||||
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
|
||||
)
|
||||
|
||||
for tool_def in tools.data:
|
||||
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
|
||||
tool_name = tool_def.identifier
|
||||
built_in_type = BuiltinTool.brave_search
|
||||
if tool_name == "web_search":
|
||||
built_in_type = BuiltinTool.brave_search
|
||||
identifier: str | BuiltinTool | None = tool_def.identifier
|
||||
if identifier == "web_search":
|
||||
identifier = BuiltinTool.brave_search
|
||||
else:
|
||||
built_in_type = BuiltinTool(tool_name)
|
||||
identifier = BuiltinTool(identifier)
|
||||
else:
|
||||
# add if tool_name is unspecified or the tool_def identifier is the same as the tool_name
|
||||
if input_tool_name in (None, tool_def.identifier):
|
||||
identifier = tool_def.identifier
|
||||
else:
|
||||
identifier = None
|
||||
|
||||
if tool_name_to_def.get(built_in_type, None):
|
||||
raise ValueError(f"Tool {built_in_type} already exists")
|
||||
|
||||
tool_name_to_def[built_in_type] = ToolDefinition(
|
||||
tool_name=built_in_type,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[built_in_type] = tool_def.toolgroup_id
|
||||
continue
|
||||
|
||||
if tool_name_to_def.get(tool_def.identifier, None):
|
||||
raise ValueError(f"Tool {tool_def.identifier} already exists")
|
||||
if tool_name in (None, tool_def.identifier):
|
||||
if tool_name_to_def.get(identifier, None):
|
||||
raise ValueError(f"Tool {identifier} already exists")
|
||||
if identifier:
|
||||
tool_name_to_def[tool_def.identifier] = ToolDefinition(
|
||||
tool_name=tool_def.identifier,
|
||||
tool_name=identifier,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
|
|
@ -840,9 +823,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
|
||||
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
|
||||
|
||||
return list(tool_name_to_def.values()), tool_to_group
|
||||
self.tool_defs, self.tool_name_to_args = list(tool_name_to_def.values()), tool_name_to_args
|
||||
|
||||
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
|
||||
"""Parse a toolgroup name into its components.
|
||||
|
|
@ -861,15 +844,46 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_group, tool_name = split_names[0], None
|
||||
return tool_group, tool_name
|
||||
|
||||
async def execute_tool_call_maybe(
|
||||
self,
|
||||
session_id: str,
|
||||
tool_call: ToolCall,
|
||||
) -> ToolInvocationResult:
|
||||
tool_name = tool_call.tool_name
|
||||
registered_tool_names = [tool_def.tool_name for tool_def in self.tool_defs]
|
||||
if tool_name not in registered_tool_names:
|
||||
raise ValueError(
|
||||
f"Tool {tool_name} not found in provided tools, registered tools: {', '.join([str(x) for x in registered_tool_names])}"
|
||||
)
|
||||
if isinstance(tool_name, BuiltinTool):
|
||||
if tool_name == BuiltinTool.brave_search:
|
||||
tool_name_str = WEB_SEARCH_TOOL
|
||||
else:
|
||||
tool_name_str = tool_name.value
|
||||
else:
|
||||
tool_name_str = tool_name
|
||||
|
||||
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=tool_name_str,
|
||||
kwargs={
|
||||
"session_id": session_id,
|
||||
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
||||
**tool_call.arguments,
|
||||
**self.tool_name_to_args.get(tool_name_str, {}),
|
||||
},
|
||||
)
|
||||
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
|
||||
return result
|
||||
|
||||
async def handle_documents(
|
||||
self,
|
||||
session_id: str,
|
||||
documents: List[Document],
|
||||
input_messages: List[Message],
|
||||
tool_defs: Dict[str, ToolDefinition],
|
||||
) -> None:
|
||||
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in tool_defs)
|
||||
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in tool_defs)
|
||||
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in self.tool_defs)
|
||||
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in self.tool_defs)
|
||||
content_items = []
|
||||
url_items = []
|
||||
pattern = re.compile("^(https?://|file://|data:)")
|
||||
|
|
@ -892,16 +906,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
if memory_tool and code_interpreter_tool:
|
||||
# if both memory and code_interpreter are available, we download the URLs
|
||||
# and attach the data to the last message.
|
||||
msg = await attachment_message(self.tempdir, url_items)
|
||||
input_messages.append(msg)
|
||||
await attachment_message(self.tempdir, url_items, input_messages[-1])
|
||||
# Since memory is present, add all the data to the memory bank
|
||||
await self.add_to_session_vector_db(session_id, documents)
|
||||
elif code_interpreter_tool:
|
||||
# if only code_interpreter is available, we download the URLs to a tempdir
|
||||
# and attach the path to them as a message to inference with the
|
||||
# assumption that the model invokes the code_interpreter tool with the path
|
||||
msg = await attachment_message(self.tempdir, url_items)
|
||||
input_messages.append(msg)
|
||||
await attachment_message(self.tempdir, url_items, input_messages[-1])
|
||||
elif memory_tool:
|
||||
# if only memory is available, we load the data from the URLs and content items to the memory bank
|
||||
await self.add_to_session_vector_db(session_id, documents)
|
||||
|
|
@ -968,8 +980,8 @@ async def load_data_from_urls(urls: List[URL]) -> List[str]:
|
|||
return data
|
||||
|
||||
|
||||
async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
|
||||
content = []
|
||||
async def attachment_message(tempdir: str, urls: List[URL], message: UserMessage) -> None:
|
||||
contents = []
|
||||
|
||||
for url in urls:
|
||||
uri = url.uri
|
||||
|
|
@ -989,48 +1001,19 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
|
|||
else:
|
||||
raise ValueError(f"Unsupported URL {url}")
|
||||
|
||||
content.append(
|
||||
contents.append(
|
||||
TextContentItem(
|
||||
text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
|
||||
)
|
||||
)
|
||||
|
||||
return ToolResponseMessage(
|
||||
call_id="",
|
||||
tool_name=BuiltinTool.code_interpreter,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
async def execute_tool_call_maybe(
|
||||
tool_runtime_api: ToolRuntime,
|
||||
session_id: str,
|
||||
tool_call: ToolCall,
|
||||
toolgroup_args: Dict[str, Dict[str, Any]],
|
||||
tool_to_group: Dict[str, str],
|
||||
) -> ToolInvocationResult:
|
||||
name = tool_call.tool_name
|
||||
group_name = tool_to_group.get(name, None)
|
||||
if group_name is None:
|
||||
raise ValueError(f"Tool {name} not found in any tool group")
|
||||
if isinstance(name, BuiltinTool):
|
||||
if name == BuiltinTool.brave_search:
|
||||
name = WEB_SEARCH_TOOL
|
||||
if isinstance(message.content, list):
|
||||
message.content.extend(contents)
|
||||
else:
|
||||
if isinstance(message.content, str):
|
||||
message.content = [TextContentItem(text=message.content)] + contents
|
||||
else:
|
||||
name = name.value
|
||||
|
||||
logger.info(f"executing tool call: {name} with args: {tool_call.arguments}")
|
||||
result = await tool_runtime_api.invoke_tool(
|
||||
tool_name=name,
|
||||
kwargs={
|
||||
"session_id": session_id,
|
||||
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
||||
**tool_call.arguments,
|
||||
**toolgroup_args.get(group_name, {}),
|
||||
},
|
||||
)
|
||||
logger.info(f"tool call {name} completed with result: {result}")
|
||||
return result
|
||||
message.content = [message.content] + contents
|
||||
|
||||
|
||||
def _interpret_content_as_attachment(
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import uuid
|
|||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
Agent,
|
||||
AgentConfig,
|
||||
AgentCreateResponse,
|
||||
Agents,
|
||||
|
|
@ -21,6 +22,8 @@ from llama_stack.apis.agents import (
|
|||
AgentTurnCreateRequest,
|
||||
AgentTurnResumeRequest,
|
||||
Document,
|
||||
ListAgentSessionsResponse,
|
||||
ListAgentsResponse,
|
||||
Session,
|
||||
Turn,
|
||||
)
|
||||
|
|
@ -84,7 +87,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
async def get_agent(self, agent_id: str) -> ChatAgent:
|
||||
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
|
||||
agent_config = await self.persistence_store.get(
|
||||
key=f"agent:{agent_id}",
|
||||
)
|
||||
|
|
@ -120,7 +123,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id: str,
|
||||
session_name: str,
|
||||
) -> AgentSessionCreateResponse:
|
||||
agent = await self.get_agent(agent_id)
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
session_id = await agent.create_session(session_name)
|
||||
return AgentSessionCreateResponse(
|
||||
|
|
@ -160,7 +163,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self,
|
||||
request: AgentTurnCreateRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self.get_agent(request.agent_id)
|
||||
agent = await self._get_agent_impl(request.agent_id)
|
||||
async for event in agent.create_and_execute_turn(request):
|
||||
yield event
|
||||
|
||||
|
|
@ -169,7 +172,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]],
|
||||
tool_responses: List[ToolResponse],
|
||||
stream: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnResumeRequest(
|
||||
|
|
@ -188,12 +191,12 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self,
|
||||
request: AgentTurnResumeRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self.get_agent(request.agent_id)
|
||||
agent = await self._get_agent_impl(request.agent_id)
|
||||
async for event in agent.resume_turn(request):
|
||||
yield event
|
||||
|
||||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||
agent = await self.get_agent(agent_id)
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
turn = await agent.storage.get_session_turn(session_id, turn_id)
|
||||
return turn
|
||||
|
||||
|
|
@ -210,7 +213,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
session_id: str,
|
||||
turn_ids: Optional[List[str]] = None,
|
||||
) -> Session:
|
||||
agent = await self.get_agent(agent_id)
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
session_info = await agent.storage.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
|
@ -232,3 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_agents(self) -> ListAgentsResponse:
|
||||
pass
|
||||
|
||||
async def get_agent(self, agent_id: str) -> Agent:
|
||||
pass
|
||||
|
||||
async def list_agent_sessions(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> ListAgentSessionsResponse:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -36,7 +36,7 @@ class AgentPersistence:
|
|||
session_info = AgentSessionInfo(
|
||||
session_id=session_id,
|
||||
session_name=name,
|
||||
started_at=datetime.now(),
|
||||
started_at=datetime.now(timezone.utc),
|
||||
)
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from typing import List
|
|||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -32,15 +33,14 @@ class ShieldRunnerMixin:
|
|||
self.output_shields = output_shields
|
||||
|
||||
async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
|
||||
responses = await asyncio.gather(
|
||||
*[
|
||||
self.safety_api.run_shield(
|
||||
async def run_shield_with_span(identifier: str):
|
||||
async with tracing.span(f"run_shield_{identifier}"):
|
||||
return await self.safety_api.run_shield(
|
||||
shield_id=identifier,
|
||||
messages=messages,
|
||||
)
|
||||
for identifier in identifiers
|
||||
]
|
||||
)
|
||||
|
||||
responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers])
|
||||
for identifier, response in zip(identifiers, responses, strict=False):
|
||||
if not response.violation:
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -4,12 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import LocalFSDatasetIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: LocalFSDatasetIOConfig,
|
||||
_deps,
|
||||
_deps: Dict[str, Any],
|
||||
):
|
||||
from .datasetio import LocalFSDatasetIOImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -3,9 +3,10 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
|
|
@ -13,6 +14,13 @@ from llama_stack.providers.utils.kvstore.config import (
|
|||
|
||||
|
||||
class LocalFSDatasetIOConfig(BaseModel):
|
||||
kvstore: KVStoreConfig = SqliteKVStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "localfs_datasetio.db").as_posix()
|
||||
) # Uses SQLite config specific to localfs storage
|
||||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="localfs_datasetio.db",
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,20 +3,14 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import base64
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pandas
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .config import LocalFSDatasetIOConfig
|
||||
|
|
@ -24,30 +18,7 @@ from .config import LocalFSDatasetIOConfig
|
|||
DATASETS_PREFIX = "localfs_datasets:"
|
||||
|
||||
|
||||
class BaseDataset(ABC):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def __len__(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def __getitem__(self, idx):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def load(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetInfo:
|
||||
dataset_def: Dataset
|
||||
dataset_impl: BaseDataset
|
||||
|
||||
|
||||
class PandasDataframeDataset(BaseDataset):
|
||||
class PandasDataframeDataset:
|
||||
def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dataset_def = dataset_def
|
||||
|
|
@ -64,23 +35,19 @@ class PandasDataframeDataset(BaseDataset):
|
|||
else:
|
||||
return self.df.iloc[idx].to_dict()
|
||||
|
||||
def _validate_dataset_schema(self, df) -> pandas.DataFrame:
|
||||
# note that we will drop any columns in dataset that are not in the schema
|
||||
df = df[self.dataset_def.dataset_schema.keys()]
|
||||
# check all columns in dataset schema are present
|
||||
assert len(df.columns) == len(self.dataset_def.dataset_schema)
|
||||
# TODO: type checking against column types in dataset schema
|
||||
return df
|
||||
|
||||
def load(self) -> None:
|
||||
async def load(self) -> None:
|
||||
if self.df is not None:
|
||||
return
|
||||
|
||||
df = get_dataframe_from_url(self.dataset_def.url)
|
||||
if df is None:
|
||||
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
|
||||
if self.dataset_def.source.type == "uri":
|
||||
self.df = await get_dataframe_from_uri(self.dataset_def.source.uri)
|
||||
elif self.dataset_def.source.type == "rows":
|
||||
self.df = pandas.DataFrame(self.dataset_def.source.rows)
|
||||
else:
|
||||
raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
|
||||
|
||||
self.df = self._validate_dataset_schema(df)
|
||||
if self.df is None:
|
||||
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
|
||||
|
||||
|
||||
class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||
|
|
@ -99,95 +66,55 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
|
||||
for dataset in stored_datasets:
|
||||
dataset = Dataset.model_validate_json(dataset)
|
||||
dataset_impl = PandasDataframeDataset(dataset)
|
||||
self.dataset_infos[dataset.identifier] = DatasetInfo(
|
||||
dataset_def=dataset,
|
||||
dataset_impl=dataset_impl,
|
||||
)
|
||||
self.dataset_infos[dataset.identifier] = dataset
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def register_dataset(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
dataset_def: Dataset,
|
||||
) -> None:
|
||||
# Store in kvstore
|
||||
key = f"{DATASETS_PREFIX}{dataset.identifier}"
|
||||
key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
|
||||
await self.kvstore.set(
|
||||
key=key,
|
||||
value=dataset.json(),
|
||||
)
|
||||
dataset_impl = PandasDataframeDataset(dataset)
|
||||
self.dataset_infos[dataset.identifier] = DatasetInfo(
|
||||
dataset_def=dataset,
|
||||
dataset_impl=dataset_impl,
|
||||
value=dataset_def.model_dump_json(),
|
||||
)
|
||||
self.dataset_infos[dataset_def.identifier] = dataset_def
|
||||
|
||||
async def unregister_dataset(self, dataset_id: str) -> None:
|
||||
key = f"{DATASETS_PREFIX}{dataset_id}"
|
||||
await self.kvstore.delete(key=key)
|
||||
del self.dataset_infos[dataset_id]
|
||||
|
||||
async def get_rows_paginated(
|
||||
async def iterrows(
|
||||
self,
|
||||
dataset_id: str,
|
||||
rows_in_page: int,
|
||||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult:
|
||||
dataset_info = self.dataset_infos.get(dataset_id)
|
||||
dataset_info.dataset_impl.load()
|
||||
start_index: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> IterrowsResponse:
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
dataset_impl = PandasDataframeDataset(dataset_def)
|
||||
await dataset_impl.load()
|
||||
|
||||
if page_token and not page_token.isnumeric():
|
||||
raise ValueError("Invalid page_token")
|
||||
start_index = start_index or 0
|
||||
|
||||
if page_token is None or len(page_token) == 0:
|
||||
next_page_token = 0
|
||||
if limit is None or limit == -1:
|
||||
end = len(dataset_impl)
|
||||
else:
|
||||
next_page_token = int(page_token)
|
||||
end = min(start_index + limit, len(dataset_impl))
|
||||
|
||||
start = next_page_token
|
||||
if rows_in_page == -1:
|
||||
end = len(dataset_info.dataset_impl)
|
||||
else:
|
||||
end = min(start + rows_in_page, len(dataset_info.dataset_impl))
|
||||
rows = dataset_impl[start_index:end]
|
||||
|
||||
rows = dataset_info.dataset_impl[start:end]
|
||||
|
||||
return PaginatedRowsResult(
|
||||
rows=rows,
|
||||
total_count=len(rows),
|
||||
next_page_token=str(end),
|
||||
return IterrowsResponse(
|
||||
data=rows,
|
||||
next_start_index=end if end < len(dataset_impl) else None,
|
||||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
dataset_info = self.dataset_infos.get(dataset_id)
|
||||
if dataset_info is None:
|
||||
raise ValueError(f"Dataset with id {dataset_id} not found")
|
||||
|
||||
dataset_impl = dataset_info.dataset_impl
|
||||
dataset_impl.load()
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
dataset_impl = PandasDataframeDataset(dataset_def)
|
||||
await dataset_impl.load()
|
||||
|
||||
new_rows_df = pandas.DataFrame(rows)
|
||||
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
|
||||
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
|
||||
|
||||
url = str(dataset_info.dataset_def.url)
|
||||
parsed_url = urlparse(url)
|
||||
|
||||
if parsed_url.scheme == "file" or not parsed_url.scheme:
|
||||
file_path = parsed_url.path
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
dataset_impl.df.to_csv(file_path, index=False)
|
||||
elif parsed_url.scheme == "data":
|
||||
# For data URLs, we need to update the base64-encoded content
|
||||
if not parsed_url.path.startswith("text/csv;base64,"):
|
||||
raise ValueError("Data URL must be a base64-encoded CSV")
|
||||
|
||||
csv_buffer = dataset_impl.df.to_csv(index=False)
|
||||
base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode("utf-8")
|
||||
dataset_info.dataset_def.url = URL(uri=f"data:text/csv;base64,{base64_content}")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,16 +3,16 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import MetaReferenceEvalConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceEvalConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .eval import MetaReferenceEvalImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -3,9 +3,10 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
|
|
@ -13,6 +14,13 @@ from llama_stack.providers.utils.kvstore.config import (
|
|||
|
||||
|
||||
class MetaReferenceEvalConfig(BaseModel):
|
||||
kvstore: KVStoreConfig = SqliteKVStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "meta_reference_eval.db").as_posix()
|
||||
) # Uses SQLite config specific to Meta Reference Eval storage
|
||||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="meta_reference_eval.db",
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,18 +12,13 @@ from llama_stack.apis.agents import Agents, StepType
|
|||
from llama_stack.apis.benchmarks import Benchmark
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.inference import Inference, UserMessage
|
||||
from llama_stack.apis.inference import Inference, SystemMessage, UserMessage
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
|
||||
MEMORY_QUERY_TOOL,
|
||||
)
|
||||
from llama_stack.providers.utils.common.data_schema_validator import (
|
||||
ColumnName,
|
||||
get_valid_schemas,
|
||||
validate_dataset_schema,
|
||||
)
|
||||
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .....apis.common.job_types import Job
|
||||
|
|
@ -88,15 +83,17 @@ class MetaReferenceEvalImpl(
|
|||
task_def = self.benchmarks[benchmark_id]
|
||||
dataset_id = task_def.dataset_id
|
||||
scoring_functions = task_def.scoring_functions
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
|
||||
# TODO (xiyan): validate dataset schema
|
||||
# dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
|
||||
all_rows = await self.datasetio_api.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
|
||||
limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
|
||||
)
|
||||
res = await self.evaluate_rows(
|
||||
benchmark_id=benchmark_id,
|
||||
input_rows=all_rows.rows,
|
||||
input_rows=all_rows.data,
|
||||
scoring_functions=scoring_functions,
|
||||
benchmark_config=benchmark_config,
|
||||
)
|
||||
|
|
@ -118,7 +115,7 @@ class MetaReferenceEvalImpl(
|
|||
for i, x in tqdm(enumerate(input_rows)):
|
||||
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
|
||||
input_messages = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [UserMessage(**x) for x in input_messages]
|
||||
input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
|
||||
|
||||
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
|
||||
session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
|
||||
|
|
@ -168,10 +165,11 @@ class MetaReferenceEvalImpl(
|
|||
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
|
||||
elif ColumnName.chat_completion_input.value in x:
|
||||
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [UserMessage(**x) for x in chat_completion_input_json]
|
||||
input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
|
||||
messages = []
|
||||
if candidate.system_message:
|
||||
messages.append(candidate.system_message)
|
||||
messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
|
||||
messages += input_messages
|
||||
response = await self.inference_api.chat_completion(
|
||||
model_id=candidate.model,
|
||||
|
|
|
|||
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Union
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
|
||||
_deps,
|
||||
_deps: Dict[str, Any],
|
||||
):
|
||||
from .inference import MetaReferenceInferenceImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
|
|
@ -213,7 +214,7 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage
|
|||
|
||||
def parse_message(json_str: str) -> ProcessingMessage:
|
||||
data = json.loads(json_str)
|
||||
return ProcessingMessageWrapper(**data).payload
|
||||
return copy.deepcopy(ProcessingMessageWrapper(**data).payload)
|
||||
|
||||
|
||||
def worker_process_entrypoint(
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.inline.inference.sentence_transformers.config import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
|
|
@ -11,7 +13,7 @@ from llama_stack.providers.inline.inference.sentence_transformers.config import
|
|||
|
||||
async def get_provider_impl(
|
||||
config: SentenceTransformersInferenceConfig,
|
||||
_deps,
|
||||
_deps: Dict[str, Any],
|
||||
):
|
||||
from .sentence_transformers import SentenceTransformersInferenceImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import VLLMConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: VLLMConfig, _deps) -> Any:
|
||||
async def get_provider_impl(config: VLLMConfig, _deps: Dict[str, Any]):
|
||||
from .vllm import VLLMInferenceImpl
|
||||
|
||||
impl = VLLMInferenceImpl(config)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
|
@ -40,7 +42,7 @@ class VLLMConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls):
|
||||
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
|
||||
"max_tokens": "${env.MAX_TOKENS:4096}",
|
||||
|
|
|
|||
|
|
@ -582,6 +582,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
tool_name=t.function.name,
|
||||
# vLLM function args come back as a string. Llama Stack expects JSON.
|
||||
arguments=json.loads(t.function.arguments),
|
||||
arguments_json=t.function.arguments,
|
||||
)
|
||||
for t in vllm_message.tool_calls
|
||||
],
|
||||
|
|
|
|||
|
|
@ -9,6 +9,9 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.type_system import (
|
||||
ChatCompletionInputType,
|
||||
DialogType,
|
||||
|
|
@ -20,7 +23,7 @@ from llama_stack.providers.utils.common.data_schema_validator import (
|
|||
validate_dataset_schema,
|
||||
)
|
||||
|
||||
EXPECTED_DATASET_SCHEMA = {
|
||||
EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
|
||||
"instruct": [
|
||||
{
|
||||
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
|
||||
|
|
@ -41,6 +44,9 @@ async def validate_input_dataset_schema(
|
|||
dataset_type: str,
|
||||
) -> None:
|
||||
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
if not dataset_def:
|
||||
raise ValueError(f"Dataset {dataset_id} does not exist.")
|
||||
|
||||
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
||||
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
|
||||
|
||||
|
|
|
|||
|
|
@ -4,9 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import TorchtunePostTrainingConfig
|
||||
|
||||
|
|
@ -15,7 +15,7 @@ from .config import TorchtunePostTrainingConfig
|
|||
|
||||
async def get_provider_impl(
|
||||
config: TorchtunePostTrainingConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .post_training import TorchtunePostTrainingImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ class TorchtuneCheckpointer:
|
|||
checkpoint_files: List[str],
|
||||
output_dir: str,
|
||||
model_type: str,
|
||||
) -> None:
|
||||
):
|
||||
# Fail fast if ``checkpoint_files`` is invalid
|
||||
# TODO: support loading more than one file
|
||||
if len(checkpoint_files) != 1:
|
||||
|
|
@ -58,7 +58,7 @@ class TorchtuneCheckpointer:
|
|||
"""
|
||||
Load Meta checkpoint from file. Currently only loading from a single file is supported.
|
||||
"""
|
||||
state_dict: Dict[str:Any] = {}
|
||||
state_dict: Dict[str, Any] = {}
|
||||
model_state_dict = safe_torch_load(self._checkpoint_path)
|
||||
if self._model_type == ModelType.LLAMA3_VISION:
|
||||
from torchtune.models.llama3_2_vision._convert_weights import (
|
||||
|
|
@ -85,10 +85,10 @@ class TorchtuneCheckpointer:
|
|||
state_dict: Dict[str, Any],
|
||||
epoch: int,
|
||||
adapter_only: bool = False,
|
||||
checkpoint_format: str = "meta",
|
||||
checkpoint_format: str | None = None,
|
||||
) -> str:
|
||||
model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"
|
||||
if checkpoint_format == "meta":
|
||||
if checkpoint_format == "meta" or checkpoint_format is None:
|
||||
self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only)
|
||||
elif checkpoint_format == "huggingface":
|
||||
# Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Callable, Dict
|
||||
from typing import Callable, Dict
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -25,10 +25,13 @@ from llama_stack.apis.post_training import DatasetFormat
|
|||
from llama_stack.models.llama.datatypes import Model
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
|
||||
BuildLoraModelCallable = Callable[..., torch.nn.Module]
|
||||
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
model_definition: Any
|
||||
tokenizer_type: Any
|
||||
model_definition: BuildLoraModelCallable
|
||||
tokenizer_type: BuildTokenizerCallable
|
||||
checkpoint_type: str
|
||||
|
||||
|
||||
|
|
@ -51,10 +54,6 @@ DATA_FORMATS: Dict[str, Transform] = {
|
|||
}
|
||||
|
||||
|
||||
BuildLoraModelCallable = Callable[..., torch.nn.Module]
|
||||
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
||||
|
||||
|
||||
def _validate_model_id(model_id: str) -> Model:
|
||||
model = resolve_model(model_id)
|
||||
if model is None or model.core_model_id.value not in MODEL_CONFIGS:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Literal, Optional
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -12,3 +12,9 @@ from pydantic import BaseModel
|
|||
class TorchtunePostTrainingConfig(BaseModel):
|
||||
torch_seed: Optional[int] = None
|
||||
checkpoint_format: Optional[Literal["meta", "huggingface"]] = "meta"
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"checkpoint_format": "meta",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ class SFTDataset(Dataset):
|
|||
if "messages" in transformed_sample:
|
||||
validate_messages(transformed_sample["messages"])
|
||||
|
||||
tokenized_dict = self._model_transform(transformed_sample)
|
||||
tokenized_dict: dict[str, Any] = self._model_transform(transformed_sample)
|
||||
|
||||
if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
|
||||
keys_str = ", ".join(tokenized_dict.keys())
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
|
|
@ -43,6 +43,9 @@ class TorchtunePostTrainingImpl:
|
|||
self.jobs = {}
|
||||
self.checkpoints_dict = {}
|
||||
|
||||
async def shutdown(self):
|
||||
pass
|
||||
|
||||
async def supervised_fine_tune(
|
||||
self,
|
||||
job_uuid: str,
|
||||
|
|
@ -61,7 +64,7 @@ class TorchtunePostTrainingImpl:
|
|||
job_status_response = PostTrainingJobStatusResponse(
|
||||
job_uuid=job_uuid,
|
||||
status=JobStatus.scheduled,
|
||||
scheduled_at=datetime.now(),
|
||||
scheduled_at=datetime.now(timezone.utc),
|
||||
)
|
||||
self.jobs[job_uuid] = job_status_response
|
||||
|
||||
|
|
@ -81,7 +84,7 @@ class TorchtunePostTrainingImpl:
|
|||
)
|
||||
|
||||
job_status_response.status = JobStatus.in_progress
|
||||
job_status_response.started_at = datetime.now()
|
||||
job_status_response.started_at = datetime.now(timezone.utc)
|
||||
|
||||
await recipe.setup()
|
||||
resources_allocated, checkpoints = await recipe.train()
|
||||
|
|
@ -90,7 +93,7 @@ class TorchtunePostTrainingImpl:
|
|||
job_status_response.resources_allocated = resources_allocated
|
||||
job_status_response.checkpoints = checkpoints
|
||||
job_status_response.status = JobStatus.completed
|
||||
job_status_response.completed_at = datetime.now()
|
||||
job_status_response.completed_at = datetime.now(timezone.utc)
|
||||
|
||||
except Exception:
|
||||
job_status_response.status = JobStatus.failed
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import gc
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
|
@ -37,10 +37,10 @@ from llama_stack.apis.common.training_types import PostTrainingMetric
|
|||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
AlgorithmConfig,
|
||||
Checkpoint,
|
||||
LoraFinetuningConfig,
|
||||
OptimizerConfig,
|
||||
QATFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
|
|
@ -73,6 +73,9 @@ class LoraFinetuningSingleDevice:
|
|||
|
||||
# Currently logging only logs limited training metrics to local disk
|
||||
# will figure out more loggings and how it works with telemetry in future PRs
|
||||
|
||||
_checkpointer: TorchtuneCheckpointer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: TorchtunePostTrainingConfig,
|
||||
|
|
@ -82,7 +85,7 @@ class LoraFinetuningSingleDevice:
|
|||
logger_config: Dict[str, Any],
|
||||
model: str,
|
||||
checkpoint_dir: Optional[str],
|
||||
algorithm_config: Optional[AlgorithmConfig],
|
||||
algorithm_config: LoraFinetuningConfig | QATFinetuningConfig | None,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets_api: Datasets,
|
||||
) -> None:
|
||||
|
|
@ -109,12 +112,12 @@ class LoraFinetuningSingleDevice:
|
|||
return str(checkpoint_dir)
|
||||
|
||||
if checkpoint_dir and checkpoint_dir != "null":
|
||||
self.checkpoint_dir = config.checkpoint_dir
|
||||
self.checkpoint_dir = checkpoint_dir
|
||||
else:
|
||||
model = resolve_model(self.model_id)
|
||||
if model is None:
|
||||
model_obj = resolve_model(self.model_id)
|
||||
if model_obj is None:
|
||||
raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
|
||||
self.checkpoint_dir = model_checkpoint_dir(model)
|
||||
self.checkpoint_dir = model_checkpoint_dir(model_obj)
|
||||
|
||||
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
|
||||
self._checkpoint_format = config.checkpoint_format
|
||||
|
|
@ -135,16 +138,16 @@ class LoraFinetuningSingleDevice:
|
|||
self.max_validation_steps = training_config.max_validation_steps
|
||||
|
||||
self._clip_grad_norm = 1.0
|
||||
self._enable_activation_checkpointing = (
|
||||
(training_config.efficiency_config.enable_activation_checkpointing)
|
||||
if training_config.efficiency_config
|
||||
else False
|
||||
)
|
||||
self._enable_activation_offloading = (
|
||||
(training_config.efficiency_config.enable_activation_offloading)
|
||||
if training_config.efficiency_config
|
||||
else False
|
||||
)
|
||||
|
||||
self._enable_activation_checkpointing = False
|
||||
self._enable_activation_offloading = False
|
||||
if training_config.efficiency_config:
|
||||
if training_config.efficiency_config.enable_activation_checkpointing:
|
||||
self._enable_activation_checkpointing = (
|
||||
training_config.efficiency_config.enable_activation_checkpointing
|
||||
)
|
||||
if training_config.efficiency_config.enable_activation_offloading:
|
||||
self._enable_activation_offloading = training_config.efficiency_config.enable_activation_offloading
|
||||
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets_api
|
||||
|
|
@ -328,13 +331,13 @@ class LoraFinetuningSingleDevice:
|
|||
batch_size: int,
|
||||
) -> Tuple[DistributedSampler, DataLoader]:
|
||||
async def fetch_rows(dataset_id: str):
|
||||
return await self.datasetio_api.get_rows_paginated(
|
||||
return await self.datasetio_api.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
limit=-1,
|
||||
)
|
||||
|
||||
all_rows = await fetch_rows(dataset_id)
|
||||
rows = all_rows.rows
|
||||
rows = all_rows.data
|
||||
|
||||
await validate_input_dataset_schema(
|
||||
datasets_api=self.datasets_api,
|
||||
|
|
@ -451,12 +454,12 @@ class LoraFinetuningSingleDevice:
|
|||
"""
|
||||
# Initialize tokens count and running loss (for grad accumulation)
|
||||
t0 = time.perf_counter()
|
||||
running_loss = 0
|
||||
running_loss: float = 0.0
|
||||
num_tokens = 0
|
||||
|
||||
# training artifacts
|
||||
checkpoints = []
|
||||
memory_stats = {}
|
||||
memory_stats: Dict[str, Any] = {}
|
||||
|
||||
# self.epochs_run should be non-zero when we're resuming from a checkpoint
|
||||
for curr_epoch in range(self.epochs_run, self.total_epochs):
|
||||
|
|
@ -484,7 +487,7 @@ class LoraFinetuningSingleDevice:
|
|||
# Loss is normalized by default so we multiply by the number of tokens
|
||||
# This way we can normalize by the total number of tokens if we're accumulating gradients
|
||||
current_loss = await self._loss_step(batch) * current_num_tokens
|
||||
running_loss += current_loss
|
||||
running_loss += current_loss.detach().item()
|
||||
current_loss.backward()
|
||||
|
||||
# Step with optimizer
|
||||
|
|
@ -500,7 +503,7 @@ class LoraFinetuningSingleDevice:
|
|||
# Update the number of steps when the weights are updated
|
||||
self.global_step += 1
|
||||
|
||||
loss_to_log = running_loss.item() / num_tokens
|
||||
loss_to_log = running_loss / num_tokens
|
||||
|
||||
pbar.update(1)
|
||||
pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")
|
||||
|
|
@ -523,7 +526,7 @@ class LoraFinetuningSingleDevice:
|
|||
)
|
||||
|
||||
# Reset running stats for the next step
|
||||
running_loss = 0
|
||||
running_loss = 0.0
|
||||
num_tokens = 0
|
||||
t0 = time.perf_counter()
|
||||
|
||||
|
|
@ -532,7 +535,7 @@ class LoraFinetuningSingleDevice:
|
|||
checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
|
||||
checkpoint = Checkpoint(
|
||||
identifier=f"{self.model_id}-sft-{curr_epoch}",
|
||||
created_at=datetime.now(),
|
||||
created_at=datetime.now(timezone.utc),
|
||||
epoch=curr_epoch,
|
||||
post_training_job_id=self.job_uuid,
|
||||
path=checkpoint_path,
|
||||
|
|
|
|||
|
|
@ -4,10 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import CodeScannerConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: CodeScannerConfig, deps):
|
||||
async def get_provider_impl(config: CodeScannerConfig, deps: Dict[str, Any]):
|
||||
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
|
||||
|
||||
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
|
||||
|
|
|
|||
|
|
@ -4,8 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeScannerConfig(BaseModel):
|
||||
pass
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -4,10 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import LlamaGuardConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: LlamaGuardConfig, deps):
|
||||
async def get_provider_impl(config: LlamaGuardConfig, deps: Dict[str, Any]):
|
||||
from .llama_guard import LlamaGuardSafetyImpl
|
||||
|
||||
assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"
|
||||
|
|
|
|||
|
|
@ -4,10 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LlamaGuardConfig(BaseModel):
|
||||
excluded_categories: List[str] = []
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"excluded_categories": [],
|
||||
}
|
||||
|
|
|
|||
|
|
@ -227,13 +227,6 @@ class LlamaGuardShield:
|
|||
if len(messages) >= 2 and (messages[0].role == Role.user.value and messages[1].role == Role.user.value):
|
||||
messages = messages[1:]
|
||||
|
||||
for i in range(1, len(messages)):
|
||||
if messages[i].role == messages[i - 1].role:
|
||||
for i, m in enumerate(messages):
|
||||
print(f"{i}: {m.role}: {m.content}")
|
||||
raise ValueError(
|
||||
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i - 1}"
|
||||
)
|
||||
return messages
|
||||
|
||||
async def run(self, messages: List[Message]) -> RunShieldResponse:
|
||||
|
|
|
|||
|
|
@ -4,10 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import PromptGuardConfig # noqa: F401
|
||||
|
||||
|
||||
async def get_provider_impl(config: PromptGuardConfig, deps):
|
||||
async def get_provider_impl(config: PromptGuardConfig, deps: Dict[str, Any]):
|
||||
from .prompt_guard import PromptGuardSafetyImpl
|
||||
|
||||
impl = PromptGuardSafetyImpl(config, deps)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
|
@ -23,3 +24,9 @@ class PromptGuardConfig(BaseModel):
|
|||
if v not in [t.value for t in PromptGuardType]:
|
||||
raise ValueError(f"Unknown prompt guard type: {v}")
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"guard_type": "injection",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,16 +3,16 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import BasicScoringConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: BasicScoringConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .scoring import BasicScoringImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,12 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BasicScoringConfig(BaseModel): ...
|
||||
class BasicScoringConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -22,11 +22,25 @@ from llama_stack.providers.utils.common.data_schema_validator import (
|
|||
)
|
||||
|
||||
from .config import BasicScoringConfig
|
||||
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
|
||||
from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn
|
||||
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
|
||||
from .scoring_fn.ifeval_scoring_fn import IfEvalScoringFn
|
||||
from .scoring_fn.regex_parser_math_response_scoring_fn import (
|
||||
RegexParserMathResponseScoringFn,
|
||||
)
|
||||
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
|
||||
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
|
||||
|
||||
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
|
||||
FIXED_FNS = [
|
||||
EqualityScoringFn,
|
||||
SubsetOfScoringFn,
|
||||
RegexParserScoringFn,
|
||||
RegexParserMathResponseScoringFn,
|
||||
BFCLScoringFn,
|
||||
IfEvalScoringFn,
|
||||
DocVQAScoringFn,
|
||||
]
|
||||
|
||||
|
||||
class BasicScoringImpl(
|
||||
|
|
@ -74,12 +88,12 @@ class BasicScoringImpl(
|
|||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
|
||||
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
all_rows = await self.datasetio_api.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
limit=-1,
|
||||
)
|
||||
res = await self.score(
|
||||
input_rows=all_rows.rows,
|
||||
input_rows=all_rows.data,
|
||||
scoring_functions=scoring_functions,
|
||||
)
|
||||
if save_results_dataset:
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue