mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-31 08:20:00 +00:00
Merge branch 'main' of https://github.com/meta-llama/llama-stack into add_nemo_customizer
This commit is contained in:
commit
f534b4c2ea
571 changed files with 229651 additions and 12956 deletions
|
|
@ -41,16 +41,36 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
|
|||
|
||||
|
||||
class Attachment(BaseModel):
|
||||
"""An attachment to an agent turn.
|
||||
|
||||
:param content: The content of the attachment.
|
||||
:param mime_type: The MIME type of the attachment.
|
||||
"""
|
||||
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str
|
||||
|
||||
|
||||
class Document(BaseModel):
|
||||
"""A document to be used by an agent.
|
||||
|
||||
:param content: The content of the document.
|
||||
:param mime_type: The MIME type of the document.
|
||||
"""
|
||||
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str
|
||||
|
||||
|
||||
class StepCommon(BaseModel):
|
||||
"""A common step in an agent turn.
|
||||
|
||||
:param turn_id: The ID of the turn.
|
||||
:param step_id: The ID of the step.
|
||||
:param started_at: The time the step started.
|
||||
:param completed_at: The time the step completed.
|
||||
"""
|
||||
|
||||
turn_id: str
|
||||
step_id: str
|
||||
started_at: Optional[datetime] = None
|
||||
|
|
@ -58,6 +78,14 @@ class StepCommon(BaseModel):
|
|||
|
||||
|
||||
class StepType(Enum):
|
||||
"""Type of the step in an agent turn.
|
||||
|
||||
:cvar inference: The step is an inference step that calls an LLM.
|
||||
:cvar tool_execution: The step is a tool execution step that executes a tool call.
|
||||
:cvar shield_call: The step is a shield call step that checks for safety violations.
|
||||
:cvar memory_retrieval: The step is a memory retrieval step that retrieves context for vector dbs.
|
||||
"""
|
||||
|
||||
inference = "inference"
|
||||
tool_execution = "tool_execution"
|
||||
shield_call = "shield_call"
|
||||
|
|
@ -66,6 +94,11 @@ class StepType(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class InferenceStep(StepCommon):
|
||||
"""An inference step in an agent turn.
|
||||
|
||||
:param model_response: The response from the LLM.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
step_type: Literal[StepType.inference.value] = StepType.inference.value
|
||||
|
|
@ -74,6 +107,12 @@ class InferenceStep(StepCommon):
|
|||
|
||||
@json_schema_type
|
||||
class ToolExecutionStep(StepCommon):
|
||||
"""A tool execution step in an agent turn.
|
||||
|
||||
:param tool_calls: The tool calls to execute.
|
||||
:param tool_responses: The tool responses from the tool calls.
|
||||
"""
|
||||
|
||||
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
|
||||
tool_calls: List[ToolCall]
|
||||
tool_responses: List[ToolResponse]
|
||||
|
|
@ -81,13 +120,25 @@ class ToolExecutionStep(StepCommon):
|
|||
|
||||
@json_schema_type
|
||||
class ShieldCallStep(StepCommon):
|
||||
"""A shield call step in an agent turn.
|
||||
|
||||
:param violation: The violation from the shield call.
|
||||
"""
|
||||
|
||||
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
|
||||
violation: Optional[SafetyViolation]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryRetrievalStep(StepCommon):
|
||||
"""A memory retrieval step in an agent turn.
|
||||
|
||||
:param vector_db_ids: The IDs of the vector databases to retrieve context from.
|
||||
:param inserted_context: The context retrieved from the vector databases.
|
||||
"""
|
||||
|
||||
step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value
|
||||
# TODO: should this be List[str]?
|
||||
vector_db_ids: str
|
||||
inserted_context: InterleavedContent
|
||||
|
||||
|
|
@ -138,17 +189,15 @@ class AgentToolGroupWithArgs(BaseModel):
|
|||
args: Dict[str, Any]
|
||||
|
||||
|
||||
AgentToolGroup = register_schema(
|
||||
Union[
|
||||
str,
|
||||
AgentToolGroupWithArgs,
|
||||
],
|
||||
name="AgentTool",
|
||||
)
|
||||
AgentToolGroup = Union[
|
||||
str,
|
||||
AgentToolGroupWithArgs,
|
||||
]
|
||||
register_schema(AgentToolGroup, name="AgentTool")
|
||||
|
||||
|
||||
class AgentConfigCommon(BaseModel):
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
|
||||
|
||||
input_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
output_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
|
|
@ -183,6 +232,23 @@ class AgentConfig(AgentConfigCommon):
|
|||
response_format: Optional[ResponseFormat] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Agent(BaseModel):
|
||||
agent_id: str
|
||||
agent_config: AgentConfig
|
||||
created_at: datetime
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListAgentsResponse(BaseModel):
|
||||
data: List[Agent]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListAgentSessionsResponse(BaseModel):
|
||||
data: List[Session]
|
||||
|
||||
|
||||
class AgentConfigOverridablePerTurn(AgentConfigCommon):
|
||||
instructions: Optional[str] = None
|
||||
|
||||
|
|
@ -244,20 +310,18 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
|
|||
turn: Turn
|
||||
|
||||
|
||||
AgentTurnResponseEventPayload = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
AgentTurnResponseStepStartPayload,
|
||||
AgentTurnResponseStepProgressPayload,
|
||||
AgentTurnResponseStepCompletePayload,
|
||||
AgentTurnResponseTurnStartPayload,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
AgentTurnResponseTurnAwaitingInputPayload,
|
||||
],
|
||||
Field(discriminator="event_type"),
|
||||
AgentTurnResponseEventPayload = Annotated[
|
||||
Union[
|
||||
AgentTurnResponseStepStartPayload,
|
||||
AgentTurnResponseStepProgressPayload,
|
||||
AgentTurnResponseStepCompletePayload,
|
||||
AgentTurnResponseTurnStartPayload,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
AgentTurnResponseTurnAwaitingInputPayload,
|
||||
],
|
||||
name="AgentTurnResponseEventPayload",
|
||||
)
|
||||
Field(discriminator="event_type"),
|
||||
]
|
||||
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -296,16 +360,13 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
|
|||
stream: Optional[bool] = False
|
||||
tool_config: Optional[ToolConfig] = None
|
||||
|
||||
# TODO (xiyan): temporary flag, will remove for 0.1.5
|
||||
allow_turn_resume: Optional[bool] = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResumeRequest(BaseModel):
|
||||
agent_id: str
|
||||
session_id: str
|
||||
turn_id: str
|
||||
tool_responses: List[ToolResponseMessage]
|
||||
tool_responses: List[ToolResponse]
|
||||
stream: Optional[bool] = False
|
||||
|
||||
|
||||
|
|
@ -338,7 +399,13 @@ class Agents(Protocol):
|
|||
async def create_agent(
|
||||
self,
|
||||
agent_config: AgentConfig,
|
||||
) -> AgentCreateResponse: ...
|
||||
) -> AgentCreateResponse:
|
||||
"""Create an agent with the given configuration.
|
||||
|
||||
:param agent_config: The configuration for the agent.
|
||||
:returns: An AgentCreateResponse with the agent ID.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST")
|
||||
async def create_agent_turn(
|
||||
|
|
@ -355,8 +422,19 @@ class Agents(Protocol):
|
|||
documents: Optional[List[Document]] = None,
|
||||
toolgroups: Optional[List[AgentToolGroup]] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
allow_turn_resume: Optional[bool] = False,
|
||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
|
||||
"""Create a new turn for an agent.
|
||||
|
||||
:param agent_id: The ID of the agent to create the turn for.
|
||||
:param session_id: The ID of the session to create the turn for.
|
||||
:param messages: List of messages to start the turn with.
|
||||
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
|
||||
:param documents: (Optional) List of documents to create the turn with.
|
||||
:param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
|
||||
:param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
|
||||
:returns: If stream=False, returns a Turn object.
|
||||
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk
|
||||
"""
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
|
||||
|
|
@ -367,7 +445,7 @@ class Agents(Protocol):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: List[ToolResponseMessage],
|
||||
tool_responses: List[ToolResponse],
|
||||
stream: Optional[bool] = False,
|
||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
|
||||
"""Resume an agent turn with executed tool call responses.
|
||||
|
|
@ -392,7 +470,15 @@ class Agents(Protocol):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
) -> Turn: ...
|
||||
) -> Turn:
|
||||
"""Retrieve an agent turn by its ID.
|
||||
|
||||
:param agent_id: The ID of the agent to get the turn for.
|
||||
:param session_id: The ID of the session to get the turn for.
|
||||
:param turn_id: The ID of the turn to get.
|
||||
:returns: A Turn.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
|
||||
|
|
@ -404,14 +490,30 @@ class Agents(Protocol):
|
|||
session_id: str,
|
||||
turn_id: str,
|
||||
step_id: str,
|
||||
) -> AgentStepResponse: ...
|
||||
) -> AgentStepResponse:
|
||||
"""Retrieve an agent step by its ID.
|
||||
|
||||
:param agent_id: The ID of the agent to get the step for.
|
||||
:param session_id: The ID of the session to get the step for.
|
||||
:param turn_id: The ID of the turn to get the step for.
|
||||
:param step_id: The ID of the step to get.
|
||||
:returns: An AgentStepResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session", method="POST")
|
||||
async def create_agent_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_name: str,
|
||||
) -> AgentSessionCreateResponse: ...
|
||||
) -> AgentSessionCreateResponse:
|
||||
"""Create a new session for an agent.
|
||||
|
||||
:param agent_id: The ID of the agent to create the session for.
|
||||
:param session_name: The name of the session to create.
|
||||
:returns: An AgentSessionCreateResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET")
|
||||
async def get_agents_session(
|
||||
|
|
@ -419,17 +521,64 @@ class Agents(Protocol):
|
|||
session_id: str,
|
||||
agent_id: str,
|
||||
turn_ids: Optional[List[str]] = None,
|
||||
) -> Session: ...
|
||||
) -> Session:
|
||||
"""Retrieve an agent session by its ID.
|
||||
|
||||
:param session_id: The ID of the session to get.
|
||||
:param agent_id: The ID of the agent to get the session for.
|
||||
:param turn_ids: (Optional) List of turn IDs to filter the session by.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE")
|
||||
async def delete_agents_session(
|
||||
self,
|
||||
session_id: str,
|
||||
agent_id: str,
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
"""Delete an agent session by its ID.
|
||||
|
||||
:param session_id: The ID of the session to delete.
|
||||
:param agent_id: The ID of the agent to delete the session for.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}", method="DELETE")
|
||||
async def delete_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
"""Delete an agent by its ID.
|
||||
|
||||
:param agent_id: The ID of the agent to delete.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents", method="GET")
|
||||
async def list_agents(self) -> ListAgentsResponse:
|
||||
"""List all agents.
|
||||
|
||||
:returns: A ListAgentsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}", method="GET")
|
||||
async def get_agent(self, agent_id: str) -> Agent:
|
||||
"""Describe an agent by its ID.
|
||||
|
||||
:param agent_id: ID of the agent.
|
||||
:returns: An Agent of the agent.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/sessions", method="GET")
|
||||
async def list_agent_sessions(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> ListAgentSessionsResponse:
|
||||
"""List all session(s) of a given agent.
|
||||
|
||||
:param agent_id: The ID of the agent to list sessions for.
|
||||
:returns: A ListAgentSessionsResponse.
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class BatchInference(Protocol):
|
|||
self,
|
||||
model: str,
|
||||
content_batch: List[InterleavedContent],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> BatchCompletionResponse: ...
|
||||
|
|
@ -50,7 +50,7 @@ class BatchInference(Protocol):
|
|||
self,
|
||||
model: str,
|
||||
messages_batch: List[List[Message]],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
# zero-shot tool definitions as input to the model
|
||||
tools: Optional[List[ToolDefinition]] = list,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ class Benchmarks(Protocol):
|
|||
async def get_benchmark(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
) -> Optional[Benchmark]: ...
|
||||
) -> Benchmark: ...
|
||||
|
||||
@webmethod(route="/eval/benchmarks", method="POST")
|
||||
async def register_benchmark(
|
||||
|
|
|
|||
|
|
@ -63,19 +63,15 @@ class TextContentItem(BaseModel):
|
|||
|
||||
|
||||
# other modalities can be added here
|
||||
InterleavedContentItem = register_schema(
|
||||
Annotated[
|
||||
Union[ImageContentItem, TextContentItem],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="InterleavedContentItem",
|
||||
)
|
||||
InterleavedContentItem = Annotated[
|
||||
Union[ImageContentItem, TextContentItem],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(InterleavedContentItem, name="InterleavedContentItem")
|
||||
|
||||
# accept a single "str" as a special case since it is common
|
||||
InterleavedContent = register_schema(
|
||||
Union[str, InterleavedContentItem, List[InterleavedContentItem]],
|
||||
name="InterleavedContent",
|
||||
)
|
||||
InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]]
|
||||
register_schema(InterleavedContent, name="InterleavedContent")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -109,10 +105,8 @@ class ToolCallDelta(BaseModel):
|
|||
|
||||
|
||||
# streaming completions send a stream of ContentDeltas
|
||||
ContentDelta = register_schema(
|
||||
Annotated[
|
||||
Union[TextDelta, ImageDelta, ToolCallDelta],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="ContentDelta",
|
||||
)
|
||||
ContentDelta = Annotated[
|
||||
Union[TextDelta, ImageDelta, ToolCallDelta],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(ContentDelta, name="ContentDelta")
|
||||
|
|
|
|||
|
|
@ -72,24 +72,22 @@ class DialogType(BaseModel):
|
|||
type: Literal["dialog"] = "dialog"
|
||||
|
||||
|
||||
ParamType = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
StringType,
|
||||
NumberType,
|
||||
BooleanType,
|
||||
ArrayType,
|
||||
ObjectType,
|
||||
JsonType,
|
||||
UnionType,
|
||||
ChatCompletionInputType,
|
||||
CompletionInputType,
|
||||
AgentTurnInputType,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
ParamType = Annotated[
|
||||
Union[
|
||||
StringType,
|
||||
NumberType,
|
||||
BooleanType,
|
||||
ArrayType,
|
||||
ObjectType,
|
||||
JsonType,
|
||||
UnionType,
|
||||
ChatCompletionInputType,
|
||||
CompletionInputType,
|
||||
AgentTurnInputType,
|
||||
],
|
||||
name="ParamType",
|
||||
)
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(ParamType, name="ParamType")
|
||||
|
||||
"""
|
||||
# TODO: recursive definition of ParamType in these containers
|
||||
|
|
|
|||
|
|
@ -13,11 +13,16 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class PaginatedRowsResult(BaseModel):
|
||||
# the rows obey the DatasetSchema for the given dataset
|
||||
rows: List[Dict[str, Any]]
|
||||
total_count: int
|
||||
next_page_token: Optional[str] = None
|
||||
class IterrowsResponse(BaseModel):
|
||||
"""
|
||||
A paginated list of rows from a dataset.
|
||||
|
||||
:param data: The rows in the current page.
|
||||
:param next_start_index: Index into dataset for the first row in the next page. None if there are no more rows.
|
||||
"""
|
||||
|
||||
data: List[Dict[str, Any]]
|
||||
next_start_index: Optional[int] = None
|
||||
|
||||
|
||||
class DatasetStore(Protocol):
|
||||
|
|
@ -29,14 +34,21 @@ class DatasetIO(Protocol):
|
|||
# keeping for aligning with inference/safety, but this is not used
|
||||
dataset_store: DatasetStore
|
||||
|
||||
@webmethod(route="/datasetio/rows", method="GET")
|
||||
async def get_rows_paginated(
|
||||
# TODO(xiyan): there's a flakiness here where setting route to "/datasets/" here will not result in proper routing
|
||||
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET")
|
||||
async def iterrows(
|
||||
self,
|
||||
dataset_id: str,
|
||||
rows_in_page: int,
|
||||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult: ...
|
||||
start_index: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> IterrowsResponse:
|
||||
"""Get a paginated list of rows from a dataset. Uses cursor-based pagination.
|
||||
|
||||
@webmethod(route="/datasetio/rows", method="POST")
|
||||
:param dataset_id: The ID of the dataset to get the rows from.
|
||||
:param start_index: Index into dataset for the first row to get. Get all rows if None.
|
||||
:param limit: The number of rows to get.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
|
||||
|
|
|
|||
|
|
@ -4,19 +4,100 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Protocol, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
class DatasetPurpose(str, Enum):
|
||||
"""
|
||||
Purpose of the dataset. Each purpose has a required input data schema.
|
||||
|
||||
:cvar post-training/messages: The dataset contains messages used for post-training.
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, world!"},
|
||||
{"role": "assistant", "content": "Hello, world!"},
|
||||
]
|
||||
}
|
||||
:cvar eval/question-answer: The dataset contains a question column and an answer column.
|
||||
{
|
||||
"question": "What is the capital of France?",
|
||||
"answer": "Paris"
|
||||
}
|
||||
:cvar eval/messages-answer: The dataset contains a messages column with list of messages and an answer column.
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, my name is John Doe."},
|
||||
{"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
|
||||
{"role": "user", "content": "What's my name?"},
|
||||
],
|
||||
"answer": "John Doe"
|
||||
}
|
||||
"""
|
||||
|
||||
post_training_messages = "post-training/messages"
|
||||
eval_question_answer = "eval/question-answer"
|
||||
eval_messages_answer = "eval/messages-answer"
|
||||
|
||||
# TODO: add more schemas here
|
||||
|
||||
|
||||
class DatasetType(Enum):
|
||||
"""
|
||||
Type of the dataset source.
|
||||
:cvar uri: The dataset can be obtained from a URI.
|
||||
:cvar rows: The dataset is stored in rows.
|
||||
"""
|
||||
|
||||
uri = "uri"
|
||||
rows = "rows"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class URIDataSource(BaseModel):
|
||||
"""A dataset that can be obtained from a URI.
|
||||
:param uri: The dataset can be obtained from a URI. E.g.
|
||||
- "https://mywebsite.com/mydata.jsonl"
|
||||
- "lsfs://mydata.jsonl"
|
||||
- "data:csv;base64,{base64_content}"
|
||||
"""
|
||||
|
||||
type: Literal["uri"] = "uri"
|
||||
uri: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RowsDataSource(BaseModel):
|
||||
"""A dataset stored in rows.
|
||||
:param rows: The dataset is stored in rows. E.g.
|
||||
- [
|
||||
{"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]}
|
||||
]
|
||||
"""
|
||||
|
||||
type: Literal["rows"] = "rows"
|
||||
rows: List[Dict[str, Any]]
|
||||
|
||||
|
||||
DataSource = Annotated[
|
||||
Union[URIDataSource, RowsDataSource],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(DataSource, name="DataSource")
|
||||
|
||||
|
||||
class CommonDatasetFields(BaseModel):
|
||||
dataset_schema: Dict[str, ParamType]
|
||||
url: URL
|
||||
"""
|
||||
Common fields for a dataset.
|
||||
"""
|
||||
|
||||
purpose: DatasetPurpose
|
||||
source: DataSource
|
||||
metadata: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Any additional metadata for this dataset",
|
||||
|
|
@ -38,8 +119,6 @@ class Dataset(CommonDatasetFields, Resource):
|
|||
|
||||
class DatasetInput(CommonDatasetFields, BaseModel):
|
||||
dataset_id: str
|
||||
provider_id: Optional[str] = None
|
||||
provider_dataset_id: Optional[str] = None
|
||||
|
||||
|
||||
class ListDatasetsResponse(BaseModel):
|
||||
|
|
@ -50,19 +129,75 @@ class Datasets(Protocol):
|
|||
@webmethod(route="/datasets", method="POST")
|
||||
async def register_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
dataset_schema: Dict[str, ParamType],
|
||||
url: URL,
|
||||
provider_dataset_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
purpose: DatasetPurpose,
|
||||
source: DataSource,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None: ...
|
||||
dataset_id: Optional[str] = None,
|
||||
) -> Dataset:
|
||||
"""
|
||||
Register a new dataset.
|
||||
|
||||
:param purpose: The purpose of the dataset. One of
|
||||
- "post-training/messages": The dataset contains a messages column with list of messages for post-training.
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, world!"},
|
||||
{"role": "assistant", "content": "Hello, world!"},
|
||||
]
|
||||
}
|
||||
- "eval/question-answer": The dataset contains a question column and an answer column for evaluation.
|
||||
{
|
||||
"question": "What is the capital of France?",
|
||||
"answer": "Paris"
|
||||
}
|
||||
- "eval/messages-answer": The dataset contains a messages column with list of messages and an answer column for evaluation.
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, my name is John Doe."},
|
||||
{"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
|
||||
{"role": "user", "content": "What's my name?"},
|
||||
],
|
||||
"answer": "John Doe"
|
||||
}
|
||||
:param source: The data source of the dataset. Ensure that the data source schema is compatible with the purpose of the dataset. Examples:
|
||||
- {
|
||||
"type": "uri",
|
||||
"uri": "https://mywebsite.com/mydata.jsonl"
|
||||
}
|
||||
- {
|
||||
"type": "uri",
|
||||
"uri": "lsfs://mydata.jsonl"
|
||||
}
|
||||
- {
|
||||
"type": "uri",
|
||||
"uri": "data:csv;base64,{base64_content}"
|
||||
}
|
||||
- {
|
||||
"type": "uri",
|
||||
"uri": "huggingface://llamastack/simpleqa?split=train"
|
||||
}
|
||||
- {
|
||||
"type": "rows",
|
||||
"rows": [
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, world!"},
|
||||
{"role": "assistant", "content": "Hello, world!"},
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
:param metadata: The metadata for the dataset.
|
||||
- E.g. {"description": "My dataset"}
|
||||
:param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="GET")
|
||||
async def get_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
) -> Optional[Dataset]: ...
|
||||
) -> Dataset: ...
|
||||
|
||||
@webmethod(route="/datasets", method="GET")
|
||||
async def list_datasets(self) -> ListDatasetsResponse: ...
|
||||
|
|
|
|||
|
|
@ -5,12 +5,16 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Api(Enum):
|
||||
providers = "providers"
|
||||
inference = "inference"
|
||||
safety = "safety"
|
||||
agents = "agents"
|
||||
|
|
@ -33,3 +37,20 @@ class Api(Enum):
|
|||
|
||||
# built-in API
|
||||
inspect = "inspect"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Error(BaseModel):
|
||||
"""
|
||||
Error response from the API. Roughly follows RFC 7807.
|
||||
|
||||
:param status: HTTP status code
|
||||
:param title: Error title, a short summary of the error which is invariant for an error type
|
||||
:param detail: Error detail, a longer human-readable description of the error
|
||||
:param instance: (Optional) A URL which can be used to retrieve more information about the specific occurrence of the error
|
||||
"""
|
||||
|
||||
status: int
|
||||
title: str
|
||||
detail: str
|
||||
instance: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -19,6 +19,13 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
|
|||
|
||||
@json_schema_type
|
||||
class ModelCandidate(BaseModel):
|
||||
"""A model candidate for evaluation.
|
||||
|
||||
:param model: The model ID to evaluate.
|
||||
:param sampling_params: The sampling parameters for the model.
|
||||
:param system_message: (Optional) The system message providing instructions or context to the model.
|
||||
"""
|
||||
|
||||
type: Literal["model"] = "model"
|
||||
model: str
|
||||
sampling_params: SamplingParams
|
||||
|
|
@ -27,18 +34,28 @@ class ModelCandidate(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class AgentCandidate(BaseModel):
|
||||
"""An agent candidate for evaluation.
|
||||
|
||||
:param config: The configuration for the agent candidate.
|
||||
"""
|
||||
|
||||
type: Literal["agent"] = "agent"
|
||||
config: AgentConfig
|
||||
|
||||
|
||||
EvalCandidate = register_schema(
|
||||
Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")],
|
||||
name="EvalCandidate",
|
||||
)
|
||||
EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")]
|
||||
register_schema(EvalCandidate, name="EvalCandidate")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BenchmarkConfig(BaseModel):
|
||||
"""A benchmark configuration for evaluation.
|
||||
|
||||
:param eval_candidate: The candidate to evaluate.
|
||||
:param scoring_params: Map between scoring function id and parameters for each scoring function you want to run
|
||||
:param num_examples: (Optional) The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated
|
||||
"""
|
||||
|
||||
eval_candidate: EvalCandidate
|
||||
scoring_params: Dict[str, ScoringFnParams] = Field(
|
||||
description="Map between scoring function id and parameters for each scoring function you want to run",
|
||||
|
|
@ -53,18 +70,32 @@ class BenchmarkConfig(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class EvaluateResponse(BaseModel):
|
||||
"""The response from an evaluation.
|
||||
|
||||
:param generations: The generations from the evaluation.
|
||||
:param scores: The scores from the evaluation.
|
||||
"""
|
||||
|
||||
generations: List[Dict[str, Any]]
|
||||
# each key in the dict is a scoring function name
|
||||
scores: Dict[str, ScoringResult]
|
||||
|
||||
|
||||
class Eval(Protocol):
|
||||
"""Llama Stack Evaluation API for running evaluations on model and agent candidates."""
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
|
||||
async def run_eval(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
task_config: BenchmarkConfig,
|
||||
) -> Job: ...
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
"""Run an evaluation on a benchmark.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param benchmark_config: The configuration for the benchmark.
|
||||
:return: The job that was created to run the evaluation.
|
||||
"""
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
|
||||
async def evaluate_rows(
|
||||
|
|
@ -72,14 +103,41 @@ class Eval(Protocol):
|
|||
benchmark_id: str,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
task_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse: ...
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
"""Evaluate a list of rows on a benchmark.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param input_rows: The rows to evaluate.
|
||||
:param scoring_functions: The scoring functions to use for the evaluation.
|
||||
:param benchmark_config: The configuration for the benchmark.
|
||||
:return: EvaluateResponse object containing generations and scores
|
||||
"""
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]: ...
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus:
|
||||
"""Get the status of a job.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param job_id: The ID of the job to get the status of.
|
||||
:return: The status of the evaluationjob.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
|
||||
async def job_cancel(self, benchmark_id: str, job_id: str) -> None: ...
|
||||
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
|
||||
"""Cancel a job.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param job_id: The ID of the job to cancel.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
|
||||
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: ...
|
||||
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
|
||||
"""Get the result of a job.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param job_id: The ID of the job to get the result of.
|
||||
:return: The result of the job.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ class Files(Protocol):
|
|||
async def get_upload_session_info(
|
||||
self,
|
||||
upload_id: str,
|
||||
) -> Optional[FileUploadResponse]:
|
||||
) -> FileUploadResponse:
|
||||
"""
|
||||
Returns information about an existsing upload session
|
||||
|
||||
|
|
|
|||
|
|
@ -117,13 +117,11 @@ class ToolResponseMessage(BaseModel):
|
|||
|
||||
:param role: Must be "tool" to identify this as a tool response
|
||||
:param call_id: Unique identifier for the tool call this response is for
|
||||
:param tool_name: Name of the tool that was called
|
||||
:param content: The response content from the tool
|
||||
"""
|
||||
|
||||
role: Literal["tool"] = "tool"
|
||||
call_id: str
|
||||
tool_name: Union[BuiltinTool, str]
|
||||
content: InterleavedContent
|
||||
|
||||
|
||||
|
|
@ -146,18 +144,16 @@ class CompletionMessage(BaseModel):
|
|||
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
|
||||
|
||||
|
||||
Message = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
UserMessage,
|
||||
SystemMessage,
|
||||
ToolResponseMessage,
|
||||
CompletionMessage,
|
||||
],
|
||||
Field(discriminator="role"),
|
||||
Message = Annotated[
|
||||
Union[
|
||||
UserMessage,
|
||||
SystemMessage,
|
||||
ToolResponseMessage,
|
||||
CompletionMessage,
|
||||
],
|
||||
name="Message",
|
||||
)
|
||||
Field(discriminator="role"),
|
||||
]
|
||||
register_schema(Message, name="Message")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -265,27 +261,25 @@ class GrammarResponseFormat(BaseModel):
|
|||
bnf: Dict[str, Any]
|
||||
|
||||
|
||||
ResponseFormat = register_schema(
|
||||
Annotated[
|
||||
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="ResponseFormat",
|
||||
)
|
||||
ResponseFormat = Annotated[
|
||||
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(ResponseFormat, name="ResponseFormat")
|
||||
|
||||
|
||||
# This is an internally used class
|
||||
class CompletionRequest(BaseModel):
|
||||
model: str
|
||||
content: InterleavedContent
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
stream: Optional[bool] = False
|
||||
logprobs: Optional[LogProbConfig] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CompletionResponse(BaseModel):
|
||||
class CompletionResponse(MetricResponseMixin):
|
||||
"""Response from a completion request.
|
||||
|
||||
:param content: The generated completion text
|
||||
|
|
@ -299,7 +293,7 @@ class CompletionResponse(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class CompletionResponseStreamChunk(BaseModel):
|
||||
class CompletionResponseStreamChunk(MetricResponseMixin):
|
||||
"""A chunk of a streamed completion response.
|
||||
|
||||
:param delta: New content generated since last chunk. This can be one or more tokens.
|
||||
|
|
@ -357,7 +351,7 @@ class ToolConfig(BaseModel):
|
|||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[Message]
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
|
||||
|
||||
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
|
||||
tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig)
|
||||
|
|
@ -368,7 +362,7 @@ class ChatCompletionRequest(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
|
||||
class ChatCompletionResponseStreamChunk(MetricResponseMixin):
|
||||
"""A chunk of a streamed chat completion response.
|
||||
|
||||
:param event: The event containing the new content
|
||||
|
|
@ -378,7 +372,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class ChatCompletionResponse(MetricResponseMixin, BaseModel):
|
||||
class ChatCompletionResponse(MetricResponseMixin):
|
||||
"""Response from a chat completion request.
|
||||
|
||||
:param completion_message: The complete response message
|
||||
|
|
@ -444,7 +438,7 @@ class Inference(Protocol):
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
|
|
@ -467,7 +461,7 @@ class Inference(Protocol):
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
|
|
|
|||
|
|
@ -11,13 +11,6 @@ from pydantic import BaseModel
|
|||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderInfo(BaseModel):
|
||||
api: str
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RouteInfo(BaseModel):
|
||||
route: str
|
||||
|
|
@ -36,19 +29,12 @@ class VersionInfo(BaseModel):
|
|||
version: str
|
||||
|
||||
|
||||
class ListProvidersResponse(BaseModel):
|
||||
data: List[ProviderInfo]
|
||||
|
||||
|
||||
class ListRoutesResponse(BaseModel):
|
||||
data: List[RouteInfo]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Inspect(Protocol):
|
||||
@webmethod(route="/inspect/providers", method="GET")
|
||||
async def list_providers(self) -> ListProvidersResponse: ...
|
||||
|
||||
@webmethod(route="/inspect/routes", method="GET")
|
||||
async def list_routes(self) -> ListRoutesResponse: ...
|
||||
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ class Models(Protocol):
|
|||
async def get_model(
|
||||
self,
|
||||
model_id: str,
|
||||
) -> Optional[Model]: ...
|
||||
) -> Model: ...
|
||||
|
||||
@webmethod(route="/models", method="POST")
|
||||
async def register_model(
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
|
@ -88,10 +88,8 @@ class QATFinetuningConfig(BaseModel):
|
|||
group_size: int
|
||||
|
||||
|
||||
AlgorithmConfig = register_schema(
|
||||
Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")],
|
||||
name="AlgorithmConfig",
|
||||
)
|
||||
AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
|
||||
register_schema(AlgorithmConfig, name="AlgorithmConfig")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -184,7 +182,7 @@ class PostTraining(Protocol):
|
|||
description="Model descriptor from `llama model list`",
|
||||
),
|
||||
checkpoint_dir: Optional[str] = None,
|
||||
algorithm_config: Optional[AlgorithmConfig] = None,
|
||||
algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None,
|
||||
) -> PostTrainingJob: ...
|
||||
|
||||
@webmethod(route="/post-training/preference-optimize", method="POST")
|
||||
|
|
@ -202,10 +200,10 @@ class PostTraining(Protocol):
|
|||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
|
||||
|
||||
@webmethod(route="/post-training/job/status", method="GET")
|
||||
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: ...
|
||||
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...
|
||||
|
||||
@webmethod(route="/post-training/job/cancel", method="POST")
|
||||
async def cancel_training_job(self, job_uuid: str) -> None: ...
|
||||
|
||||
@webmethod(route="/post-training/job/artifacts", method="GET")
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]: ...
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ...
|
||||
|
|
|
|||
|
|
@ -3,3 +3,5 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .providers import * # noqa: F401 F403
|
||||
36
llama_stack/apis/providers/providers.py
Normal file
36
llama_stack/apis/providers/providers.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderInfo(BaseModel):
|
||||
api: str
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
class ListProvidersResponse(BaseModel):
|
||||
data: List[ProviderInfo]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Providers(Protocol):
|
||||
"""
|
||||
Providers API for inspecting, listing, and modifying providers and their configurations.
|
||||
"""
|
||||
|
||||
@webmethod(route="/providers", method="GET")
|
||||
async def list_providers(self) -> ListProvidersResponse: ...
|
||||
|
||||
@webmethod(route="/providers/{provider_id}", method="GET")
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo: ...
|
||||
|
|
@ -17,6 +17,13 @@ ScoringResultRow = Dict[str, Any]
|
|||
|
||||
@json_schema_type
|
||||
class ScoringResult(BaseModel):
|
||||
"""
|
||||
A scoring result for a single row.
|
||||
|
||||
:param score_rows: The scoring result for each row. Each row is a map of column name to value.
|
||||
:param aggregated_results: Map of metric name to aggregated value
|
||||
"""
|
||||
|
||||
score_rows: List[ScoringResultRow]
|
||||
# aggregated metrics to value
|
||||
aggregated_results: Dict[str, Any]
|
||||
|
|
@ -30,6 +37,12 @@ class ScoreBatchResponse(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class ScoreResponse(BaseModel):
|
||||
"""
|
||||
The response from scoring.
|
||||
|
||||
:param results: A map of scoring function name to ScoringResult.
|
||||
"""
|
||||
|
||||
# each key in the dict is a scoring function name
|
||||
results: Dict[str, ScoringResult]
|
||||
|
||||
|
|
@ -55,4 +68,11 @@ class Scoring(Protocol):
|
|||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]],
|
||||
) -> ScoreResponse: ...
|
||||
) -> ScoreResponse:
|
||||
"""Score a list of rows.
|
||||
|
||||
:param input_rows: The rows to score.
|
||||
:param scoring_functions: The scoring functions to use for the scoring.
|
||||
:return: ScoreResponse object containing rows and aggregated results
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ class ScoringFnParamsType(Enum):
|
|||
@json_schema_type
|
||||
class AggregationFunctionType(Enum):
|
||||
average = "average"
|
||||
weighted_average = "weighted_average"
|
||||
median = "median"
|
||||
categorical_count = "categorical_count"
|
||||
accuracy = "accuracy"
|
||||
|
|
@ -78,17 +79,15 @@ class BasicScoringFnParams(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
ScoringFnParams = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
LLMAsJudgeScoringFnParams,
|
||||
RegexParserScoringFnParams,
|
||||
BasicScoringFnParams,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
ScoringFnParams = Annotated[
|
||||
Union[
|
||||
LLMAsJudgeScoringFnParams,
|
||||
RegexParserScoringFnParams,
|
||||
BasicScoringFnParams,
|
||||
],
|
||||
name="ScoringFnParams",
|
||||
)
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(ScoringFnParams, name="ScoringFnParams")
|
||||
|
||||
|
||||
class CommonScoringFnFields(BaseModel):
|
||||
|
|
@ -135,7 +134,7 @@ class ScoringFunctions(Protocol):
|
|||
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
|
||||
|
||||
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
|
||||
async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ...
|
||||
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
|
||||
|
||||
@webmethod(route="/scoring-functions", method="POST")
|
||||
async def register_scoring_function(
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ class Shields(Protocol):
|
|||
async def list_shields(self) -> ListShieldsResponse: ...
|
||||
|
||||
@webmethod(route="/shields/{identifier:path}", method="GET")
|
||||
async def get_shield(self, identifier: str) -> Optional[Shield]: ...
|
||||
async def get_shield(self, identifier: str) -> Shield: ...
|
||||
|
||||
@webmethod(route="/shields", method="POST")
|
||||
async def register_shield(
|
||||
|
|
|
|||
|
|
@ -96,6 +96,13 @@ class MetricEvent(EventCommon):
|
|||
unit: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricInResponse(BaseModel):
|
||||
metric: str
|
||||
value: Union[int, float]
|
||||
unit: Optional[str] = None
|
||||
|
||||
|
||||
# This is a short term solution to allow inference API to return metrics
|
||||
# The ideal way to do this is to have a way for all response types to include metrics
|
||||
# and all metric events logged to the telemetry API to be inlcuded with the response
|
||||
|
|
@ -117,7 +124,7 @@ class MetricEvent(EventCommon):
|
|||
|
||||
|
||||
class MetricResponseMixin(BaseModel):
|
||||
metrics: Optional[List[MetricEvent]] = None
|
||||
metrics: Optional[List[MetricInResponse]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -139,16 +146,14 @@ class SpanEndPayload(BaseModel):
|
|||
status: SpanStatus
|
||||
|
||||
|
||||
StructuredLogPayload = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
SpanStartPayload,
|
||||
SpanEndPayload,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
StructuredLogPayload = Annotated[
|
||||
Union[
|
||||
SpanStartPayload,
|
||||
SpanEndPayload,
|
||||
],
|
||||
name="StructuredLogPayload",
|
||||
)
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(StructuredLogPayload, name="StructuredLogPayload")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -157,17 +162,15 @@ class StructuredLogEvent(EventCommon):
|
|||
payload: StructuredLogPayload
|
||||
|
||||
|
||||
Event = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
UnstructuredLogEvent,
|
||||
MetricEvent,
|
||||
StructuredLogEvent,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
Event = Annotated[
|
||||
Union[
|
||||
UnstructuredLogEvent,
|
||||
MetricEvent,
|
||||
StructuredLogEvent,
|
||||
],
|
||||
name="Event",
|
||||
)
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(Event, name="Event")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -17,6 +17,15 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
|
|||
|
||||
@json_schema_type
|
||||
class RAGDocument(BaseModel):
|
||||
"""
|
||||
A document to be used for document ingestion in the RAG Tool.
|
||||
|
||||
:param document_id: The unique identifier for the document.
|
||||
:param content: The content of the document.
|
||||
:param mime_type: The MIME type of the document.
|
||||
:param metadata: Additional metadata for the document.
|
||||
"""
|
||||
|
||||
document_id: str
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str | None = None
|
||||
|
|
@ -49,16 +58,14 @@ class LLMRAGQueryGeneratorConfig(BaseModel):
|
|||
template: str
|
||||
|
||||
|
||||
RAGQueryGeneratorConfig = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
DefaultRAGQueryGeneratorConfig,
|
||||
LLMRAGQueryGeneratorConfig,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
RAGQueryGeneratorConfig = Annotated[
|
||||
Union[
|
||||
DefaultRAGQueryGeneratorConfig,
|
||||
LLMRAGQueryGeneratorConfig,
|
||||
],
|
||||
name="RAGQueryGeneratorConfig",
|
||||
)
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ class VectorDBs(Protocol):
|
|||
async def get_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
) -> Optional[VectorDB]: ...
|
||||
) -> VectorDB: ...
|
||||
|
||||
@webmethod(route="/vector-dbs", method="POST")
|
||||
async def register_vector_db(
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import json
|
|||
import os
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
|
@ -343,7 +343,7 @@ def _hf_download(
|
|||
"You can find your token by visiting https://huggingface.co/settings/tokens"
|
||||
)
|
||||
except RepositoryNotFoundError:
|
||||
parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub.")
|
||||
parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub or incorrect Hugging Face token.")
|
||||
except Exception as e:
|
||||
parser.error(e)
|
||||
|
||||
|
|
@ -404,7 +404,7 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
|
|||
d = json.load(f)
|
||||
manifest = Manifest(**d)
|
||||
|
||||
if datetime.now() > manifest.expires_on:
|
||||
if datetime.now(timezone.utc) > manifest.expires_on:
|
||||
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
|
||||
|
||||
console = Console()
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import argparse
|
|||
from .download import Download
|
||||
from .model import ModelParser
|
||||
from .stack import StackParser
|
||||
from .stack.utils import print_subcommand_description
|
||||
from .verify_download import VerifyDownload
|
||||
|
||||
|
||||
|
|
@ -20,6 +21,7 @@ class LlamaCLIParser:
|
|||
prog="llama",
|
||||
description="Welcome to the Llama CLI",
|
||||
add_help=True,
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
|
||||
# Default command is to print help
|
||||
|
|
@ -33,6 +35,8 @@ class LlamaCLIParser:
|
|||
Download.create(subparsers)
|
||||
VerifyDownload.create(subparsers)
|
||||
|
||||
print_subcommand_description(self.parser, subparsers)
|
||||
|
||||
def parse_args(self) -> argparse.Namespace:
|
||||
return self.parser.parse_args()
|
||||
|
||||
|
|
|
|||
|
|
@ -7,8 +7,6 @@
|
|||
import argparse
|
||||
import json
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
|
|
@ -52,11 +50,12 @@ class ModelDescribe(Subcommand):
|
|||
)
|
||||
return
|
||||
|
||||
headers = [
|
||||
"Model",
|
||||
model.descriptor(),
|
||||
]
|
||||
|
||||
rows = [
|
||||
(
|
||||
colored("Model", "white", attrs=["bold"]),
|
||||
colored(model.descriptor(), "white", attrs=["bold"]),
|
||||
),
|
||||
("Hugging Face ID", model.huggingface_repo or "<Not Available>"),
|
||||
("Description", model.description),
|
||||
("Context Length", f"{model.max_seq_length // 1024}K tokens"),
|
||||
|
|
@ -65,7 +64,7 @@ class ModelDescribe(Subcommand):
|
|||
]
|
||||
|
||||
if model.recommended_sampling_params is not None:
|
||||
sampling_params = model.recommended_sampling_params.dict()
|
||||
sampling_params = model.recommended_sampling_params.model_dump()
|
||||
for k in ("max_tokens", "repetition_penalty"):
|
||||
del sampling_params[k]
|
||||
rows.append(
|
||||
|
|
@ -77,5 +76,6 @@ class ModelDescribe(Subcommand):
|
|||
|
||||
print_table(
|
||||
rows,
|
||||
headers,
|
||||
separate_rows=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from llama_stack.cli.model.list import ModelList
|
|||
from llama_stack.cli.model.prompt_format import ModelPromptFormat
|
||||
from llama_stack.cli.model.remove import ModelRemove
|
||||
from llama_stack.cli.model.verify_download import ModelVerifyDownload
|
||||
from llama_stack.cli.stack.utils import print_subcommand_description
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
|
||||
|
||||
|
|
@ -24,6 +25,7 @@ class ModelParser(Subcommand):
|
|||
"model",
|
||||
prog="llama model",
|
||||
description="Work with llama models",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
|
||||
self.parser.set_defaults(func=lambda args: self.parser.print_help())
|
||||
|
|
@ -37,3 +39,5 @@ class ModelParser(Subcommand):
|
|||
ModelDescribe.create(subparsers)
|
||||
ModelVerifyDownload.create(subparsers)
|
||||
ModelRemove.create(subparsers)
|
||||
|
||||
print_subcommand_description(self.parser, subparsers)
|
||||
|
|
|
|||
|
|
@ -7,10 +7,14 @@
|
|||
import argparse
|
||||
import textwrap
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
|
||||
|
||||
ROOT_DIR = Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
class ModelPromptFormat(Subcommand):
|
||||
"""Llama model cli for describe a model prompt format (message formats)"""
|
||||
|
|
@ -37,8 +41,14 @@ class ModelPromptFormat(Subcommand):
|
|||
"-m",
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="llama3_1",
|
||||
help="Model Family (llama3_1, llama3_X, etc.)",
|
||||
help="Example: Llama3.1-8B or Llama3.2-11B-Vision, etc\n"
|
||||
"(Run `llama model list` to see a list of valid model names)",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"-l",
|
||||
"--list",
|
||||
action="store_true",
|
||||
help="List all available models",
|
||||
)
|
||||
|
||||
def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
|
||||
|
|
@ -48,18 +58,42 @@ class ModelPromptFormat(Subcommand):
|
|||
supported_model_ids = [
|
||||
m for m in CoreModelId if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
|
||||
]
|
||||
model_str = "\n".join([m.value for m in supported_model_ids])
|
||||
|
||||
model_list = [m.value for m in supported_model_ids]
|
||||
|
||||
if args.list:
|
||||
headers = ["Model(s)"]
|
||||
rows = []
|
||||
for m in model_list:
|
||||
rows.append(
|
||||
[
|
||||
m,
|
||||
]
|
||||
)
|
||||
print_table(
|
||||
rows,
|
||||
headers,
|
||||
separate_rows=True,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
model_id = CoreModelId(args.model_name)
|
||||
except ValueError:
|
||||
self.parser.error(f"{args.model_name} is not a valid Model. Choose one from --\n{model_str}")
|
||||
self.parser.error(
|
||||
f"{args.model_name} is not a valid Model. Choose one from the list of valid models. "
|
||||
f"Run `llama model list` to see the valid model names."
|
||||
)
|
||||
|
||||
if model_id not in supported_model_ids:
|
||||
self.parser.error(f"{model_id} is not a valid Model. Choose one from --\n {model_str}")
|
||||
self.parser.error(
|
||||
f"{model_id} is not a valid Model. Choose one from the list of valid models. "
|
||||
f"Run `llama model list` to see the valid model names."
|
||||
)
|
||||
|
||||
llama_3_1_file = importlib.resources.files("llama_models") / "llama3_1/prompt_format.md"
|
||||
llama_3_2_text_file = importlib.resources.files("llama_models") / "llama3_2/text_prompt_format.md"
|
||||
llama_3_2_vision_file = importlib.resources.files("llama_models") / "llama3_2/vision_prompt_format.md"
|
||||
llama_3_1_file = ROOT_DIR / "models" / "llama" / "llama3_1" / "prompt_format.md"
|
||||
llama_3_2_text_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "text_prompt_format.md"
|
||||
llama_3_2_vision_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "vision_prompt_format.md"
|
||||
if model_family(model_id) == ModelFamily.llama3_1:
|
||||
with importlib.resources.as_file(llama_3_1_file) as f:
|
||||
content = f.open("r").read()
|
||||
|
|
|
|||
|
|
@ -38,8 +38,8 @@ from llama_stack.distribution.distribution import get_provider_registry
|
|||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.distribution.utils.exec import formulate_run_args, in_notebook, run_with_pty
|
||||
from llama_stack.distribution.utils.image_types import ImageType
|
||||
from llama_stack.distribution.utils.exec import formulate_run_args, run_command
|
||||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
|
||||
|
|
@ -65,8 +65,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
if args.image_type == "venv":
|
||||
current_venv = os.environ.get("VIRTUAL_ENV")
|
||||
image_name = args.image_name or current_venv
|
||||
if not image_name and in_notebook():
|
||||
image_name = "__system__"
|
||||
elif args.image_type == "conda":
|
||||
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
|
||||
image_name = args.image_name or current_conda_env
|
||||
|
|
@ -143,7 +141,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
completer=WordCompleter(available_providers),
|
||||
complete_while_typing=True,
|
||||
validator=Validator.from_callable(
|
||||
lambda x: x in available_providers,
|
||||
lambda x: x in available_providers, # noqa: B023 - see https://github.com/astral-sh/ruff/issues/7847
|
||||
error_message="Invalid provider, use <TAB> to see options",
|
||||
),
|
||||
)
|
||||
|
|
@ -172,7 +170,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
)
|
||||
sys.exit(1)
|
||||
|
||||
if build_config.image_type == ImageType.container.value and not args.image_name:
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not args.image_name:
|
||||
cprint(
|
||||
"Please specify --image-name when building a container from a config file",
|
||||
color="red",
|
||||
|
|
@ -215,7 +213,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
|
||||
run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))])
|
||||
run_with_pty(run_args)
|
||||
run_command(run_args)
|
||||
|
||||
|
||||
def _generate_run_config(
|
||||
|
|
@ -228,7 +226,7 @@ def _generate_run_config(
|
|||
"""
|
||||
apis = list(build_config.distribution_spec.providers.keys())
|
||||
run_config = StackRunConfig(
|
||||
container_image=(image_name if build_config.image_type == ImageType.container.value else None),
|
||||
container_image=(image_name if build_config.image_type == LlamaStackImageType.CONTAINER.value else None),
|
||||
image_name=image_name,
|
||||
apis=apis,
|
||||
providers={},
|
||||
|
|
@ -250,7 +248,7 @@ def _generate_run_config(
|
|||
|
||||
config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
|
||||
if hasattr(config_type, "sample_run_config"):
|
||||
config = config_type.sample_run_config(__distro_dir__=f"distributions/{image_name}")
|
||||
config = config_type.sample_run_config(__distro_dir__=f"~/.llama/distributions/{image_name}")
|
||||
else:
|
||||
config = {}
|
||||
|
||||
|
|
@ -281,16 +279,18 @@ def _run_stack_build_command_from_build_config(
|
|||
template_name: Optional[str] = None,
|
||||
config_path: Optional[str] = None,
|
||||
) -> str:
|
||||
if build_config.image_type == ImageType.container.value:
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
if template_name:
|
||||
image_name = f"distribution-{template_name}"
|
||||
else:
|
||||
if not image_name:
|
||||
raise ValueError("Please specify an image name when building a container image without a template")
|
||||
elif build_config.image_type == ImageType.conda.value:
|
||||
elif build_config.image_type == LlamaStackImageType.CONDA.value:
|
||||
if not image_name:
|
||||
raise ValueError("Please specify an image name when building a conda image")
|
||||
elif build_config.image_type == ImageType.venv.value:
|
||||
elif build_config.image_type == LlamaStackImageType.VENV.value:
|
||||
if not image_name and os.environ.get("UV_SYSTEM_PYTHON"):
|
||||
image_name = "__system__"
|
||||
if not image_name:
|
||||
raise ValueError("Please specify an image name when building a venv image")
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ class StackBuild(Subcommand):
|
|||
"build",
|
||||
prog="llama stack build",
|
||||
description="Build a Llama stack container",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._run_stack_build_command)
|
||||
|
|
@ -26,7 +26,7 @@ class StackBuild(Subcommand):
|
|||
"--config",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to a config file to use for the build. You can find example configs in llama_stack/distribution/**/build.yaml. If this argument is not provided, you will be prompted to enter information interactively",
|
||||
help="Path to a config file to use for the build. You can find example configs in llama_stack/distributions/**/build.yaml. If this argument is not provided, you will be prompted to enter information interactively",
|
||||
)
|
||||
|
||||
self.parser.add_argument(
|
||||
|
|
|
|||
|
|
@ -9,9 +9,12 @@ import os
|
|||
from pathlib import Path
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
logger = get_logger(name=__name__, category="server")
|
||||
|
||||
|
||||
class StackRun(Subcommand):
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
|
|
@ -20,7 +23,7 @@ class StackRun(Subcommand):
|
|||
"run",
|
||||
prog="llama stack run",
|
||||
description="""Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._run_stack_run_cmd)
|
||||
|
|
@ -34,12 +37,13 @@ class StackRun(Subcommand):
|
|||
self.parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
help="Port to run the server on. Defaults to 8321",
|
||||
help="Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT.",
|
||||
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--image-name",
|
||||
type=str,
|
||||
default=os.environ.get("CONDA_DEFAULT_ENV"),
|
||||
help="Name of the image to run. Defaults to the current conda environment",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
|
|
@ -75,19 +79,10 @@ class StackRun(Subcommand):
|
|||
|
||||
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
|
||||
import yaml
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.build import ImageType
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.utils.config_dirs import (
|
||||
BUILDS_BASE_DIR,
|
||||
DISTRIBS_BASE_DIR,
|
||||
)
|
||||
from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
|
||||
|
||||
if not args.config:
|
||||
self.parser.error("Must specify a config file to run")
|
||||
return
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.distribution.utils.exec import formulate_run_args, run_command
|
||||
|
||||
config_file = Path(args.config)
|
||||
has_yaml_suffix = args.config.endswith(".yaml")
|
||||
|
|
@ -99,14 +94,6 @@ class StackRun(Subcommand):
|
|||
if config_file.exists():
|
||||
template_name = args.config
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if it's a build config saved to conda dir
|
||||
config_file = Path(BUILDS_BASE_DIR / ImageType.conda.value / f"{args.config}-run.yaml")
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if it's a build config saved to container dir
|
||||
config_file = Path(BUILDS_BASE_DIR / ImageType.container.value / f"{args.config}-run.yaml")
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if it's a build config saved to ~/.llama dir
|
||||
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
|
||||
|
|
@ -115,11 +102,23 @@ class StackRun(Subcommand):
|
|||
self.parser.error(
|
||||
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
|
||||
)
|
||||
return
|
||||
|
||||
print(f"Using run configuration: {config_file}")
|
||||
config_dict = yaml.safe_load(config_file.read_text())
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
if not config_file.is_file():
|
||||
self.parser.error(
|
||||
f"Config file must be a valid file path, '{config_file}’ is not a file: type={type(config_file)}"
|
||||
)
|
||||
|
||||
logger.info(f"Using run configuration: {config_file}")
|
||||
|
||||
try:
|
||||
config_dict = yaml.safe_load(config_file.read_text())
|
||||
except yaml.parser.ParserError as e:
|
||||
self.parser.error(f"failed to load config file '{config_file}':\n {e}")
|
||||
|
||||
try:
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
except AttributeError as e:
|
||||
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
|
||||
|
||||
run_args = formulate_run_args(args.image_type, args.image_name, config, template_name)
|
||||
|
||||
|
|
@ -129,20 +128,12 @@ class StackRun(Subcommand):
|
|||
|
||||
for env_var in args.env:
|
||||
if "=" not in env_var:
|
||||
cprint(
|
||||
f"Environment variable '{env_var}' must be in KEY=VALUE format",
|
||||
color="red",
|
||||
)
|
||||
return
|
||||
self.parser.error(f"Environment variable '{env_var}' must be in KEY=VALUE format")
|
||||
key, value = env_var.split("=", 1) # split on first = only
|
||||
if not key:
|
||||
cprint(
|
||||
f"Environment variable '{env_var}' has empty key",
|
||||
color="red",
|
||||
)
|
||||
return
|
||||
self.parser.error(f"Environment variable '{env_var}' has empty key")
|
||||
run_args.extend(["--env", f"{key}={value}"])
|
||||
|
||||
if args.tls_keyfile and args.tls_certfile:
|
||||
run_args.extend(["--tls-keyfile", args.tls_keyfile, "--tls-certfile", args.tls_certfile])
|
||||
run_with_pty(run_args)
|
||||
run_command(run_args)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
import argparse
|
||||
from importlib.metadata import version
|
||||
|
||||
from llama_stack.cli.stack.utils import print_subcommand_description
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
|
||||
from .build import StackBuild
|
||||
|
|
@ -22,6 +23,7 @@ class StackParser(Subcommand):
|
|||
"stack",
|
||||
prog="llama stack",
|
||||
description="Operations for the Llama Stack / Distributions",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
|
||||
self.parser.add_argument(
|
||||
|
|
@ -39,3 +41,5 @@ class StackParser(Subcommand):
|
|||
StackListApis.create(subparsers)
|
||||
StackListProviders.create(subparsers)
|
||||
StackRun.create(subparsers)
|
||||
|
||||
print_subcommand_description(self.parser, subparsers)
|
||||
|
|
|
|||
14
llama_stack/cli/stack/utils.py
Normal file
14
llama_stack/cli/stack/utils.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
def print_subcommand_description(parser, subparsers):
|
||||
"""Print descriptions of subcommands."""
|
||||
description_text = ""
|
||||
for name, subcommand in subparsers.choices.items():
|
||||
description = subcommand.description
|
||||
description_text += f" {name:<21} {description}\n"
|
||||
parser.epilog = description_text
|
||||
|
|
@ -1,127 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from llama_stack.distribution.configure import (
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
parse_and_maybe_upgrade_config,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def up_to_date_config():
|
||||
return yaml.safe_load(
|
||||
"""
|
||||
version: {version}
|
||||
image_name: foo
|
||||
apis_to_serve: []
|
||||
built_at: {built_at}
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: provider1
|
||||
provider_type: inline::meta-reference
|
||||
config: {{}}
|
||||
safety:
|
||||
- provider_id: provider1
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
llama_guard_shield:
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
disable_input_check: false
|
||||
disable_output_check: false
|
||||
enable_prompt_guard: false
|
||||
memory:
|
||||
- provider_id: provider1
|
||||
provider_type: inline::meta-reference
|
||||
config: {{}}
|
||||
""".format(version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat())
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def old_config():
|
||||
return yaml.safe_load(
|
||||
"""
|
||||
image_name: foo
|
||||
built_at: {built_at}
|
||||
apis_to_serve: []
|
||||
routing_table:
|
||||
inference:
|
||||
- provider_type: remote::ollama
|
||||
config:
|
||||
host: localhost
|
||||
port: 11434
|
||||
routing_key: Llama3.2-1B-Instruct
|
||||
- provider_type: inline::meta-reference
|
||||
config:
|
||||
model: Llama3.1-8B-Instruct
|
||||
routing_key: Llama3.1-8B-Instruct
|
||||
safety:
|
||||
- routing_key: ["shield1", "shield2"]
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
llama_guard_shield:
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
disable_input_check: false
|
||||
disable_output_check: false
|
||||
enable_prompt_guard: false
|
||||
memory:
|
||||
- routing_key: vector
|
||||
provider_type: inline::meta-reference
|
||||
config: {{}}
|
||||
api_providers:
|
||||
telemetry:
|
||||
provider_type: noop
|
||||
config: {{}}
|
||||
""".format(built_at=datetime.now().isoformat())
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def invalid_config():
|
||||
return yaml.safe_load(
|
||||
"""
|
||||
routing_table: {}
|
||||
api_providers: {}
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def test_parse_and_maybe_upgrade_config_up_to_date(up_to_date_config):
|
||||
result = parse_and_maybe_upgrade_config(up_to_date_config)
|
||||
assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
assert "inference" in result.providers
|
||||
|
||||
|
||||
def test_parse_and_maybe_upgrade_config_old_format(old_config):
|
||||
result = parse_and_maybe_upgrade_config(old_config)
|
||||
assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
assert all(api in result.providers for api in ["inference", "safety", "memory", "telemetry"])
|
||||
safety_provider = result.providers["safety"][0]
|
||||
assert safety_provider.provider_type == "meta-reference"
|
||||
assert "llama_guard_shield" in safety_provider.config
|
||||
|
||||
inference_providers = result.providers["inference"]
|
||||
assert len(inference_providers) == 2
|
||||
assert set(x.provider_id for x in inference_providers) == {
|
||||
"remote::ollama-00",
|
||||
"meta-reference-01",
|
||||
}
|
||||
|
||||
ollama = inference_providers[0]
|
||||
assert ollama.provider_type == "remote::ollama"
|
||||
assert ollama.config["port"] == 11434
|
||||
|
||||
|
||||
def test_parse_and_maybe_upgrade_config_invalid(invalid_config):
|
||||
with pytest.raises(ValueError):
|
||||
parse_and_maybe_upgrade_config(invalid_config)
|
||||
81
llama_stack/distribution/access_control.py
Normal file
81
llama_stack/distribution/access_control.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutableObjectWithProvider
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
def check_access(obj: RoutableObjectWithProvider, user_attributes: Optional[Dict[str, Any]] = None) -> bool:
|
||||
"""Check if the current user has access to the given object, based on access attributes.
|
||||
|
||||
Access control algorithm:
|
||||
1. If the resource has no access_attributes, access is GRANTED to all authenticated users
|
||||
2. If the user has no attributes, access is DENIED to any object with access_attributes defined
|
||||
3. For each attribute category in the resource's access_attributes:
|
||||
a. If the user lacks that category, access is DENIED
|
||||
b. If the user has the category but none of the required values, access is DENIED
|
||||
c. If the user has at least one matching value in each required category, access is GRANTED
|
||||
|
||||
Example:
|
||||
# Resource requires:
|
||||
access_attributes = AccessAttributes(
|
||||
roles=["admin", "data-scientist"],
|
||||
teams=["ml-team"]
|
||||
)
|
||||
|
||||
# User has:
|
||||
user_attributes = {
|
||||
"roles": ["data-scientist", "engineer"],
|
||||
"teams": ["ml-team", "infra-team"],
|
||||
"projects": ["llama-3"]
|
||||
}
|
||||
|
||||
# Result: Access GRANTED
|
||||
# - User has the "data-scientist" role (matches one of the required roles)
|
||||
# - AND user is part of the "ml-team" (matches the required team)
|
||||
# - The extra "projects" attribute is ignored
|
||||
|
||||
Args:
|
||||
obj: The resource object to check access for
|
||||
|
||||
Returns:
|
||||
bool: True if access is granted, False if denied
|
||||
"""
|
||||
# If object has no access attributes, allow access by default
|
||||
if not hasattr(obj, "access_attributes") or not obj.access_attributes:
|
||||
return True
|
||||
|
||||
# If no user attributes, deny access to objects with access control
|
||||
if not user_attributes:
|
||||
return False
|
||||
|
||||
obj_attributes = obj.access_attributes.model_dump(exclude_none=True)
|
||||
if not obj_attributes:
|
||||
return True
|
||||
|
||||
# Check each attribute category (requires ALL categories to match)
|
||||
for attr_key, required_values in obj_attributes.items():
|
||||
user_values = user_attributes.get(attr_key, [])
|
||||
|
||||
if not user_values:
|
||||
logger.debug(
|
||||
f"Access denied to {obj.type} '{obj.identifier}': missing required attribute category '{attr_key}'"
|
||||
)
|
||||
return False
|
||||
|
||||
if not any(val in user_values for val in required_values):
|
||||
logger.debug(
|
||||
f"Access denied to {obj.type} '{obj.identifier}': "
|
||||
f"no match for attribute '{attr_key}', required one of {required_values}"
|
||||
)
|
||||
return False
|
||||
|
||||
logger.debug(f"Access granted to {obj.type} '{obj.identifier}'")
|
||||
return True
|
||||
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
import importlib.resources
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
|
|
@ -15,9 +14,8 @@ from termcolor import cprint
|
|||
|
||||
from llama_stack.distribution.datatypes import BuildConfig, Provider
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
|
||||
from llama_stack.distribution.utils.exec import run_command, run_with_pty
|
||||
from llama_stack.distribution.utils.image_types import ImageType
|
||||
from llama_stack.distribution.utils.exec import run_command
|
||||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
@ -96,18 +94,16 @@ def build_image(
|
|||
normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
|
||||
normal_deps += SERVER_DEPENDENCIES
|
||||
|
||||
if build_config.image_type == ImageType.container.value:
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")
|
||||
args = [
|
||||
script,
|
||||
template_or_config,
|
||||
image_name,
|
||||
container_base,
|
||||
str(build_file_path),
|
||||
str(BUILDS_BASE_DIR / ImageType.container.value),
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
elif build_config.image_type == ImageType.conda.value:
|
||||
elif build_config.image_type == LlamaStackImageType.CONDA.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
|
||||
args = [
|
||||
script,
|
||||
|
|
@ -115,7 +111,7 @@ def build_image(
|
|||
str(build_file_path),
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
elif build_config.image_type == ImageType.venv.value:
|
||||
elif build_config.image_type == LlamaStackImageType.VENV.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh")
|
||||
args = [
|
||||
script,
|
||||
|
|
@ -126,11 +122,7 @@ def build_image(
|
|||
if special_deps:
|
||||
args.append("#".join(special_deps))
|
||||
|
||||
is_terminal = sys.stdin.isatty()
|
||||
if is_terminal:
|
||||
return_code = run_with_pty(args)
|
||||
else:
|
||||
return_code = run_command(args)
|
||||
return_code = run_command(args)
|
||||
|
||||
if return_code != 0:
|
||||
log.error(
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
|
||||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
||||
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
|
||||
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
||||
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
|
|
@ -16,8 +16,8 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
|
|||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
|
||||
fi
|
||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
||||
echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
|
||||
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
|
||||
fi
|
||||
|
||||
if [ "$#" -lt 3 ]; then
|
||||
|
|
@ -52,7 +52,7 @@ ensure_conda_env_python310() {
|
|||
local python_version="3.10"
|
||||
|
||||
# Check if conda command is available
|
||||
if ! command -v conda &>/dev/null; then
|
||||
if ! is_command_available conda; then
|
||||
printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
|
@ -87,8 +87,6 @@ ensure_conda_env_python310() {
|
|||
# these packages are damaged in test-pypi, so install them first
|
||||
uv pip install fastapi libcst
|
||||
uv pip install --extra-index-url https://test.pypi.org/simple/ \
|
||||
llama-models==$TEST_PYPI_VERSION \
|
||||
llama-stack-client==$TEST_PYPI_VERSION \
|
||||
llama-stack==$TEST_PYPI_VERSION \
|
||||
$pip_dependencies
|
||||
if [ -n "$special_pip_deps" ]; then
|
||||
|
|
@ -111,22 +109,21 @@ ensure_conda_env_python310() {
|
|||
else
|
||||
PYPI_VERSION="${PYPI_VERSION:-}"
|
||||
if [ -n "$PYPI_VERSION" ]; then
|
||||
SPEC_VERSION="llama-stack==${PYPI_VERSION} llama-models==${PYPI_VERSION} llama-stack-client==${PYPI_VERSION}"
|
||||
SPEC_VERSION="llama-stack==${PYPI_VERSION}"
|
||||
else
|
||||
SPEC_VERSION="llama-stack"
|
||||
fi
|
||||
uv pip install --no-cache-dir $SPEC_VERSION
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
|
||||
printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}\n" >&2
|
||||
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: $LLAMA_STACK_CLIENT_DIR${NC}\n" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
printf "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR\n"
|
||||
uv pip uninstall llama-models
|
||||
uv pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
|
||||
printf "Installing from LLAMA_STACK_CLIENT_DIR: $LLAMA_STACK_CLIENT_DIR\n"
|
||||
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
|
||||
fi
|
||||
|
||||
# Install pip dependencies
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
#!/bin/bash
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
|
@ -6,7 +6,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
|
||||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
||||
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
|
||||
|
||||
|
|
@ -20,35 +19,38 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
|
|||
# mounting is not supported by docker buildx, so we use COPY instead
|
||||
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
|
||||
|
||||
if [ "$#" -lt 6 ]; then
|
||||
if [ "$#" -lt 4 ]; then
|
||||
# This only works for templates
|
||||
echo "Usage: $0 <template_or_config> <image_name> <container_base> <build_file_path> <host_build_dir> <pip_dependencies> [<special_pip_deps>]" >&2
|
||||
echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<special_pip_deps>]" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
template_or_config="$1"
|
||||
image_name="$2"
|
||||
container_base="$3"
|
||||
build_file_path="$4"
|
||||
host_build_dir="$5"
|
||||
pip_dependencies="$6"
|
||||
special_pip_deps="${7:-}"
|
||||
shift
|
||||
image_name="$1"
|
||||
shift
|
||||
container_base="$1"
|
||||
shift
|
||||
pip_dependencies="$1"
|
||||
shift
|
||||
special_pip_deps="${1:-}"
|
||||
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
|
||||
CONTAINER_OPTS=${CONTAINER_OPTS:-}
|
||||
CONTAINER_OPTS=${CONTAINER_OPTS:---progress=plain}
|
||||
|
||||
TEMP_DIR=$(mktemp -d)
|
||||
|
||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
source "$SCRIPT_DIR/common.sh"
|
||||
|
||||
add_to_container() {
|
||||
local input
|
||||
output_file="$TEMP_DIR/Containerfile"
|
||||
if [ -t 0 ]; then
|
||||
printf '%s\n' "$1" >>"$output_file"
|
||||
|
|
@ -58,15 +60,21 @@ add_to_container() {
|
|||
fi
|
||||
}
|
||||
|
||||
# Check if container command is available
|
||||
if ! is_command_available $CONTAINER_BINARY; then
|
||||
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Update and install UBI9 components if UBI9 base image is used
|
||||
if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then
|
||||
add_to_container << EOF
|
||||
FROM $container_base
|
||||
WORKDIR /app
|
||||
|
||||
RUN microdnf -y update && microdnf install -y iputils net-tools wget \
|
||||
RUN dnf -y update && dnf install -y iputils net-tools wget \
|
||||
vim-minimal python3.11 python3.11-pip python3.11-wheel \
|
||||
python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && microdnf clean all
|
||||
python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all
|
||||
|
||||
ENV UV_SYSTEM_PYTHON=1
|
||||
RUN pip install uv
|
||||
|
|
@ -107,7 +115,6 @@ EOF
|
|||
fi
|
||||
|
||||
stack_mount="/app/llama-stack-source"
|
||||
models_mount="/app/llama-models-source"
|
||||
client_mount="/app/llama-stack-client-source"
|
||||
|
||||
install_local_package() {
|
||||
|
|
@ -131,10 +138,6 @@ EOF
|
|||
}
|
||||
|
||||
|
||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
||||
install_local_package "$LLAMA_MODELS_DIR" "$models_mount" "LLAMA_MODELS_DIR"
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
install_local_package "$LLAMA_STACK_CLIENT_DIR" "$client_mount" "LLAMA_STACK_CLIENT_DIR"
|
||||
fi
|
||||
|
|
@ -150,12 +153,12 @@ EOF
|
|||
add_to_container << EOF
|
||||
RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
|
||||
--index-strategy unsafe-best-match \
|
||||
llama-models==$TEST_PYPI_VERSION llama-stack-client==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION
|
||||
llama-stack==$TEST_PYPI_VERSION
|
||||
|
||||
EOF
|
||||
else
|
||||
if [ -n "$PYPI_VERSION" ]; then
|
||||
SPEC_VERSION="llama-stack==${PYPI_VERSION} llama-models==${PYPI_VERSION} llama-stack-client==${PYPI_VERSION}"
|
||||
SPEC_VERSION="llama-stack==${PYPI_VERSION}"
|
||||
else
|
||||
SPEC_VERSION="llama-stack"
|
||||
fi
|
||||
|
|
@ -165,6 +168,11 @@ EOF
|
|||
fi
|
||||
fi
|
||||
|
||||
# remove uv after installation
|
||||
add_to_container << EOF
|
||||
RUN pip uninstall -y uv
|
||||
EOF
|
||||
|
||||
# if template_or_config ends with .yaml, it is not a template and we should not use the --template flag
|
||||
if [[ "$template_or_config" != *.yaml ]]; then
|
||||
add_to_container << EOF
|
||||
|
|
@ -185,26 +193,28 @@ RUN mkdir -p /.llama /.cache
|
|||
RUN chmod -R g+rw /app /.llama /.cache
|
||||
EOF
|
||||
|
||||
printf "Containerfile created successfully in $TEMP_DIR/Containerfile\n\n"
|
||||
cat $TEMP_DIR/Containerfile
|
||||
printf "Containerfile created successfully in %s/Containerfile\n\n" "$TEMP_DIR"
|
||||
cat "$TEMP_DIR"/Containerfile
|
||||
printf "\n"
|
||||
|
||||
mounts=""
|
||||
# Start building the CLI arguments
|
||||
CLI_ARGS=()
|
||||
|
||||
# Read CONTAINER_OPTS and put it in an array
|
||||
read -ra CLI_ARGS <<< "$CONTAINER_OPTS"
|
||||
|
||||
if [ "$USE_COPY_NOT_MOUNT" != "true" ]; then
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):$stack_mount"
|
||||
fi
|
||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
||||
mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
|
||||
CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_DIR"):$stack_mount")
|
||||
fi
|
||||
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
mounts="$mounts -v $(readlink -f $LLAMA_STACK_CLIENT_DIR):$client_mount"
|
||||
CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_CLIENT_DIR"):$client_mount")
|
||||
fi
|
||||
fi
|
||||
|
||||
if command -v selinuxenabled &>/dev/null && selinuxenabled; then
|
||||
if is_command_available selinuxenabled && selinuxenabled; then
|
||||
# Disable SELinux labels -- we don't want to relabel the llama-stack source dir
|
||||
CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
|
||||
CLI_ARGS+=("--security-opt" "label=disable")
|
||||
fi
|
||||
|
||||
# Set version tag based on PyPI version
|
||||
|
|
@ -212,7 +222,7 @@ if [ -n "$PYPI_VERSION" ]; then
|
|||
version_tag="$PYPI_VERSION"
|
||||
elif [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
version_tag="test-$TEST_PYPI_VERSION"
|
||||
elif [[ -n "$LLAMA_STACK_DIR" || -n "$LLAMA_MODELS_DIR" ]]; then
|
||||
elif [[ -n "$LLAMA_STACK_DIR" || -n "$LLAMA_STACK_CLIENT_DIR" ]]; then
|
||||
version_tag="dev"
|
||||
else
|
||||
URL="https://pypi.org/pypi/llama-stack/json"
|
||||
|
|
@ -225,11 +235,11 @@ image_tag="$image_name:$version_tag"
|
|||
# Detect platform architecture
|
||||
ARCH=$(uname -m)
|
||||
if [ -n "$BUILD_PLATFORM" ]; then
|
||||
PLATFORM="--platform $BUILD_PLATFORM"
|
||||
CLI_ARGS+=("--platform" "$BUILD_PLATFORM")
|
||||
elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
|
||||
PLATFORM="--platform linux/arm64"
|
||||
CLI_ARGS+=("--platform" "linux/arm64")
|
||||
elif [ "$ARCH" = "x86_64" ]; then
|
||||
PLATFORM="--platform linux/amd64"
|
||||
CLI_ARGS+=("--platform" "linux/amd64")
|
||||
else
|
||||
echo "Unsupported architecture: $ARCH"
|
||||
exit 1
|
||||
|
|
@ -238,8 +248,12 @@ fi
|
|||
echo "PWD: $(pwd)"
|
||||
echo "Containerfile: $TEMP_DIR/Containerfile"
|
||||
set -x
|
||||
$CONTAINER_BINARY build $CONTAINER_OPTS $PLATFORM -t $image_tag \
|
||||
-f "$TEMP_DIR/Containerfile" "." $mounts --progress=plain
|
||||
|
||||
$CONTAINER_BINARY build \
|
||||
"${CLI_ARGS[@]}" \
|
||||
-t "$image_tag" \
|
||||
-f "$TEMP_DIR/Containerfile" \
|
||||
"."
|
||||
|
||||
# clean up tmp/configs
|
||||
set +x
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@
|
|||
# TODO: combine this with build_conda_env.sh since it is almost identical
|
||||
# the only difference is that we don't do any conda-specific setup
|
||||
|
||||
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
|
||||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
||||
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
|
||||
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
||||
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
|
|
@ -21,13 +21,13 @@ VIRTUAL_ENV=${VIRTUAL_ENV:-}
|
|||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
|
||||
fi
|
||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
||||
echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
|
||||
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
|
||||
fi
|
||||
|
||||
if [ "$#" -lt 2 ]; then
|
||||
echo "Usage: $0 <distribution_type> <env_name> <pip_dependencies> [<special_pip_deps>]" >&2
|
||||
echo "Example: $0 <distribution_type> mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2
|
||||
echo "Usage: $0 <env_name> <pip_dependencies> [<special_pip_deps>]" >&2
|
||||
echo "Example: $0 mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
|
@ -95,7 +95,7 @@ run() {
|
|||
# we are building a command line so word splitting is expected
|
||||
uv pip install --extra-index-url https://test.pypi.org/simple/ \
|
||||
--index-strategy unsafe-best-match \
|
||||
llama-models=="$TEST_PYPI_VERSION" llama-stack=="$TEST_PYPI_VERSION" \
|
||||
llama-stack=="$TEST_PYPI_VERSION" \
|
||||
$pip_dependencies
|
||||
if [ -n "$special_pip_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$special_pip_deps"
|
||||
|
|
@ -120,15 +120,14 @@ run() {
|
|||
uv pip install --no-cache-dir llama-stack
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
|
||||
printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_MODELS_DIR" >&2
|
||||
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_CLIENT_DIR" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
printf "Installing from LLAMA_MODELS_DIR: %s\n" "$LLAMA_MODELS_DIR"
|
||||
uv pip uninstall llama-models
|
||||
uv pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
|
||||
printf "Installing from LLAMA_STACK_CLIENT_DIR: %s\n" "$LLAMA_STACK_CLIENT_DIR"
|
||||
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
|
||||
fi
|
||||
|
||||
# Install pip dependencies
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ def configure_single_provider(registry: Dict[str, ProviderSpec], provider: Provi
|
|||
return Provider(
|
||||
provider_id=provider.provider_id,
|
||||
provider_type=provider.provider_type,
|
||||
config=cfg.dict(),
|
||||
config=cfg.model_dump(),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -62,7 +62,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
|
|||
if config.apis:
|
||||
apis_to_serve = config.apis
|
||||
else:
|
||||
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
|
||||
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect, Api.providers)]
|
||||
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
|
|
|||
|
|
@ -1,47 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
|
||||
CONTAINER_OPTS=${CONTAINER_OPTS:-}
|
||||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
error_handler() {
|
||||
echo "Error occurred in script at line: ${1}" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
trap 'error_handler ${LINENO}' ERR
|
||||
|
||||
if [ $# -lt 2 ]; then
|
||||
echo "Usage: $0 <container name> <build file path>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
container_image="$1"
|
||||
host_build_dir="$2"
|
||||
container_build_dir="/app/builds"
|
||||
|
||||
if command -v selinuxenabled &> /dev/null && selinuxenabled; then
|
||||
# Disable SELinux labels
|
||||
CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
|
||||
fi
|
||||
|
||||
mounts=""
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):/app/llama-stack-source"
|
||||
fi
|
||||
|
||||
set -x
|
||||
$CONTAINER_BINARY run $CONTAINER_OPTS -it \
|
||||
--entrypoint "/usr/local/bin/llama" \
|
||||
-v $host_build_dir:$container_build_dir \
|
||||
$mounts \
|
||||
$container_image \
|
||||
stack configure ./llamastack-build.yaml --output-dir $container_build_dir
|
||||
|
|
@ -14,6 +14,7 @@ from llama_stack.apis.datasets import Dataset, DatasetInput
|
|||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Model, ModelInput
|
||||
from llama_stack.apis.resource import Resource
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
|
||||
|
|
@ -31,6 +32,115 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
|
|||
RoutingKey = Union[str, List[str]]
|
||||
|
||||
|
||||
class AccessAttributes(BaseModel):
|
||||
"""Structured representation of user attributes for access control.
|
||||
|
||||
This model defines a structured approach to representing user attributes
|
||||
with common standard categories for access control.
|
||||
|
||||
Standard attribute categories include:
|
||||
- roles: Role-based attributes (e.g., admin, data-scientist)
|
||||
- teams: Team-based attributes (e.g., ml-team, infra-team)
|
||||
- projects: Project access attributes (e.g., llama-3, customer-insights)
|
||||
- namespaces: Namespace-based access control for resource isolation
|
||||
"""
|
||||
|
||||
# Standard attribute categories - the minimal set we need now
|
||||
roles: Optional[List[str]] = Field(
|
||||
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
|
||||
)
|
||||
|
||||
teams: Optional[List[str]] = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
|
||||
|
||||
projects: Optional[List[str]] = Field(
|
||||
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
|
||||
)
|
||||
|
||||
namespaces: Optional[List[str]] = Field(
|
||||
default=None, description="Namespace-based access control for resource isolation"
|
||||
)
|
||||
|
||||
|
||||
class ResourceWithACL(Resource):
|
||||
"""Extension of Resource that adds attribute-based access control capabilities.
|
||||
|
||||
This class adds an optional access_attributes field that allows fine-grained control
|
||||
over which users can access each resource. When attributes are defined, a user must have
|
||||
matching attributes to access the resource.
|
||||
|
||||
Attribute Matching Algorithm:
|
||||
1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users
|
||||
2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects")
|
||||
3. The matching algorithm requires ALL categories to match (AND relationship between categories)
|
||||
4. Within each category, ANY value match is sufficient (OR relationship within a category)
|
||||
|
||||
Examples:
|
||||
# Resource visible to everyone (no access control)
|
||||
model = Model(identifier="llama-2", ...)
|
||||
|
||||
# Resource visible only to admins
|
||||
model = Model(
|
||||
identifier="gpt-4",
|
||||
access_attributes=AccessAttributes(roles=["admin"])
|
||||
)
|
||||
|
||||
# Resource visible to data scientists on the ML team
|
||||
model = Model(
|
||||
identifier="private-model",
|
||||
access_attributes=AccessAttributes(
|
||||
roles=["data-scientist", "researcher"],
|
||||
teams=["ml-team"]
|
||||
)
|
||||
)
|
||||
# ^ User must have at least one of the roles AND be on the ml-team
|
||||
|
||||
# Resource visible to users with specific project access
|
||||
vector_db = VectorDB(
|
||||
identifier="customer-embeddings",
|
||||
access_attributes=AccessAttributes(
|
||||
projects=["customer-insights"],
|
||||
namespaces=["confidential"]
|
||||
)
|
||||
)
|
||||
# ^ User must have access to the customer-insights project AND have confidential namespace
|
||||
"""
|
||||
|
||||
access_attributes: Optional[AccessAttributes] = None
|
||||
|
||||
|
||||
# Use the extended Resource for all routable objects
|
||||
class ModelWithACL(Model, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class ShieldWithACL(Shield, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class VectorDBWithACL(VectorDB, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class DatasetWithACL(Dataset, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class ScoringFnWithACL(ScoringFn, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class BenchmarkWithACL(Benchmark, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class ToolWithACL(Tool, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
class ToolGroupWithACL(ToolGroup, ResourceWithACL):
|
||||
pass
|
||||
|
||||
|
||||
RoutableObject = Union[
|
||||
Model,
|
||||
Shield,
|
||||
|
|
@ -45,14 +155,14 @@ RoutableObject = Union[
|
|||
|
||||
RoutableObjectWithProvider = Annotated[
|
||||
Union[
|
||||
Model,
|
||||
Shield,
|
||||
VectorDB,
|
||||
Dataset,
|
||||
ScoringFn,
|
||||
Benchmark,
|
||||
Tool,
|
||||
ToolGroup,
|
||||
ModelWithACL,
|
||||
ShieldWithACL,
|
||||
VectorDBWithACL,
|
||||
DatasetWithACL,
|
||||
ScoringFnWithACL,
|
||||
BenchmarkWithACL,
|
||||
ToolWithACL,
|
||||
ToolGroupWithACL,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
|
@ -117,6 +227,21 @@ class Provider(BaseModel):
|
|||
config: Dict[str, Any]
|
||||
|
||||
|
||||
class LoggingConfig(BaseModel):
|
||||
category_levels: Dict[str, str] = Field(
|
||||
default_factory=Dict,
|
||||
description="""
|
||||
Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""",
|
||||
)
|
||||
|
||||
|
||||
class AuthenticationConfig(BaseModel):
|
||||
endpoint: str = Field(
|
||||
...,
|
||||
description="Endpoint URL to validate authentication tokens",
|
||||
)
|
||||
|
||||
|
||||
class ServerConfig(BaseModel):
|
||||
port: int = Field(
|
||||
default=8321,
|
||||
|
|
@ -132,6 +257,10 @@ class ServerConfig(BaseModel):
|
|||
default=None,
|
||||
description="Path to TLS key file for HTTPS",
|
||||
)
|
||||
auth: Optional[AuthenticationConfig] = Field(
|
||||
default=None,
|
||||
description="Authentication configuration for the server",
|
||||
)
|
||||
|
||||
|
||||
class StackRunConfig(BaseModel):
|
||||
|
|
@ -176,6 +305,8 @@ a default SQLite store will be used.""",
|
|||
benchmarks: List[BenchmarkInput] = Field(default_factory=list)
|
||||
tool_groups: List[ToolGroupInput] = Field(default_factory=list)
|
||||
|
||||
logging: Optional[LoggingConfig] = Field(default=None, description="Configuration for Llama Stack Logging")
|
||||
|
||||
server: ServerConfig = Field(
|
||||
default_factory=ServerConfig,
|
||||
description="Configuration for the HTTP(S) server",
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
|
|||
|
||||
|
||||
def stack_apis() -> List[Api]:
|
||||
return [v for v in Api]
|
||||
return list(Api)
|
||||
|
||||
|
||||
class AutoRoutedApiInfo(BaseModel):
|
||||
|
|
@ -55,8 +55,8 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
|
|||
|
||||
|
||||
def providable_apis() -> List[Api]:
|
||||
routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
|
||||
return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
|
||||
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
|
||||
return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
|
||||
|
||||
|
||||
def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||
|
|
|
|||
|
|
@ -11,9 +11,7 @@ from pydantic import BaseModel
|
|||
from llama_stack.apis.inspect import (
|
||||
HealthInfo,
|
||||
Inspect,
|
||||
ListProvidersResponse,
|
||||
ListRoutesResponse,
|
||||
ProviderInfo,
|
||||
RouteInfo,
|
||||
VersionInfo,
|
||||
)
|
||||
|
|
@ -39,24 +37,6 @@ class DistributionInspectImpl(Inspect):
|
|||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_providers(self) -> ListProvidersResponse:
|
||||
run_config = self.config.run_config
|
||||
|
||||
ret = []
|
||||
for api, providers in run_config.providers.items():
|
||||
ret.extend(
|
||||
[
|
||||
ProviderInfo(
|
||||
api=api,
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
)
|
||||
for p in providers
|
||||
]
|
||||
)
|
||||
|
||||
return ListProvidersResponse(data=ret)
|
||||
|
||||
async def list_routes(self) -> ListRoutesResponse:
|
||||
run_config = self.config.run_config
|
||||
|
||||
|
|
|
|||
|
|
@ -32,7 +32,10 @@ from termcolor import cprint
|
|||
from llama_stack.distribution.build import print_pip_install_help
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.resolver import ProviderRegistry
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.distribution.stack import (
|
||||
|
|
@ -41,8 +44,10 @@ from llama_stack.distribution.stack import (
|
|||
redact_sensitive_fields,
|
||||
replace_env_vars,
|
||||
)
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.distribution.utils.exec import in_notebook
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
CURRENT_TRACE_CONTEXT,
|
||||
end_trace,
|
||||
setup_logger,
|
||||
start_trace,
|
||||
|
|
@ -104,7 +109,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
|
|||
logger.warning(
|
||||
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
|
||||
)
|
||||
return value
|
||||
raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e
|
||||
|
||||
|
||||
class LlamaStackAsLibraryClient(LlamaStackClient):
|
||||
|
|
@ -160,6 +165,9 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
except StopAsyncIteration:
|
||||
pass
|
||||
finally:
|
||||
pending = asyncio.all_tasks(loop)
|
||||
if pending:
|
||||
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
|
||||
loop.close()
|
||||
|
||||
return sync_generator()
|
||||
|
|
@ -262,21 +270,25 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
if not self.endpoint_impls:
|
||||
raise ValueError("Client not initialized")
|
||||
|
||||
# Create headers with provider data if available
|
||||
headers = {}
|
||||
if self.provider_data:
|
||||
set_request_provider_data({"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)})
|
||||
headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
|
||||
|
||||
if stream:
|
||||
response = await self._call_streaming(
|
||||
cast_to=cast_to,
|
||||
options=options,
|
||||
stream_cls=stream_cls,
|
||||
)
|
||||
else:
|
||||
response = await self._call_non_streaming(
|
||||
cast_to=cast_to,
|
||||
options=options,
|
||||
)
|
||||
return response
|
||||
# Use context manager for provider data
|
||||
with request_provider_data_context(headers):
|
||||
if stream:
|
||||
response = await self._call_streaming(
|
||||
cast_to=cast_to,
|
||||
options=options,
|
||||
stream_cls=stream_cls,
|
||||
)
|
||||
else:
|
||||
response = await self._call_non_streaming(
|
||||
cast_to=cast_to,
|
||||
options=options,
|
||||
)
|
||||
return response
|
||||
|
||||
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
|
||||
"""Find the matching endpoint implementation for a given method and path.
|
||||
|
|
@ -324,6 +336,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
await end_trace()
|
||||
|
||||
json_content = json.dumps(convert_pydantic_to_json_value(result))
|
||||
|
||||
mock_response = httpx.Response(
|
||||
status_code=httpx.codes.OK,
|
||||
content=json_content.encode("utf-8"),
|
||||
|
|
@ -335,7 +348,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
url=options.url,
|
||||
params=options.params,
|
||||
headers=options.headers or {},
|
||||
json=options.json_data,
|
||||
json=convert_pydantic_to_json_value(body),
|
||||
),
|
||||
)
|
||||
response = APIResponse(
|
||||
|
|
@ -373,9 +386,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
finally:
|
||||
await end_trace()
|
||||
|
||||
wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
|
||||
|
||||
mock_response = httpx.Response(
|
||||
status_code=httpx.codes.OK,
|
||||
content=gen(),
|
||||
content=wrapped_gen,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
|
|
@ -384,7 +399,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
url=options.url,
|
||||
params=options.params,
|
||||
headers=options.headers or {},
|
||||
json=options.json_data,
|
||||
json=convert_pydantic_to_json_value(body),
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
|||
66
llama_stack/distribution/providers.py
Normal file
66
llama_stack/distribution/providers.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .datatypes import StackRunConfig
|
||||
from .stack import redact_sensitive_fields
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class ProviderImplConfig(BaseModel):
|
||||
run_config: StackRunConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config, deps):
|
||||
impl = ProviderImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
class ProviderImpl(Providers):
|
||||
def __init__(self, config, deps):
|
||||
self.config = config
|
||||
self.deps = deps
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("ProviderImpl.shutdown")
|
||||
pass
|
||||
|
||||
async def list_providers(self) -> ListProvidersResponse:
|
||||
run_config = self.config.run_config
|
||||
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
|
||||
ret = []
|
||||
for api, providers in safe_config.providers.items():
|
||||
ret.extend(
|
||||
[
|
||||
ProviderInfo(
|
||||
api=api,
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
config=p.config,
|
||||
)
|
||||
for p in providers
|
||||
]
|
||||
)
|
||||
|
||||
return ListProvidersResponse(data=ret)
|
||||
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
|
||||
all_providers = await self.list_providers()
|
||||
for p in all_providers.data:
|
||||
if p.provider_id == provider_id:
|
||||
return p
|
||||
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
|
@ -4,16 +4,40 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Dict
|
||||
from typing import Any, ContextManager, Dict, List, Optional
|
||||
|
||||
from .utils.dynamic import instantiate_class_type
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_THREAD_LOCAL = threading.local()
|
||||
# Context variable for request provider data and auth attributes
|
||||
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
|
||||
|
||||
|
||||
class RequestProviderDataContext(ContextManager):
|
||||
"""Context manager for request provider data"""
|
||||
|
||||
def __init__(
|
||||
self, provider_data: Optional[Dict[str, Any]] = None, auth_attributes: Optional[Dict[str, List[str]]] = None
|
||||
):
|
||||
self.provider_data = provider_data or {}
|
||||
if auth_attributes:
|
||||
self.provider_data["__auth_attributes"] = auth_attributes
|
||||
|
||||
self.token = None
|
||||
|
||||
def __enter__(self):
|
||||
# Save the current value and set the new one
|
||||
self.token = PROVIDER_DATA_VAR.set(self.provider_data)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# Restore the previous value
|
||||
if self.token is not None:
|
||||
PROVIDER_DATA_VAR.reset(self.token)
|
||||
|
||||
|
||||
class NeedsRequestProviderData:
|
||||
|
|
@ -26,7 +50,7 @@ class NeedsRequestProviderData:
|
|||
if not validator_class:
|
||||
raise ValueError(f"Provider {provider_type} does not have a validator")
|
||||
|
||||
val = getattr(_THREAD_LOCAL, "provider_data_header_value", None)
|
||||
val = PROVIDER_DATA_VAR.get()
|
||||
if not val:
|
||||
return None
|
||||
|
||||
|
|
@ -36,25 +60,42 @@ class NeedsRequestProviderData:
|
|||
return provider_data
|
||||
except Exception as e:
|
||||
log.error(f"Error parsing provider data: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def set_request_provider_data(headers: Dict[str, str]):
|
||||
def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, Any]]:
|
||||
"""Parse provider data from request headers"""
|
||||
keys = [
|
||||
"X-LlamaStack-Provider-Data",
|
||||
"x-llamastack-provider-data",
|
||||
]
|
||||
val = None
|
||||
for key in keys:
|
||||
val = headers.get(key, None)
|
||||
if val:
|
||||
break
|
||||
|
||||
if not val:
|
||||
return
|
||||
return None
|
||||
|
||||
try:
|
||||
val = json.loads(val)
|
||||
return json.loads(val)
|
||||
except json.JSONDecodeError:
|
||||
log.error("Provider data not encoded as a JSON object!", val)
|
||||
return
|
||||
log.error("Provider data not encoded as a JSON object!")
|
||||
return None
|
||||
|
||||
_THREAD_LOCAL.provider_data_header_value = val
|
||||
|
||||
def request_provider_data_context(
|
||||
headers: Dict[str, str], auth_attributes: Optional[Dict[str, List[str]]] = None
|
||||
) -> ContextManager:
|
||||
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
|
||||
provider_data = parse_request_provider_data(headers)
|
||||
return RequestProviderDataContext(provider_data, auth_attributes)
|
||||
|
||||
|
||||
def get_auth_attributes() -> Optional[Dict[str, List[str]]]:
|
||||
"""Helper to retrieve auth attributes from the provider data context"""
|
||||
provider_data = PROVIDER_DATA_VAR.get()
|
||||
if not provider_data:
|
||||
return None
|
||||
return provider_data.get("__auth_attributes")
|
||||
|
|
|
|||
|
|
@ -5,8 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Any, Dict, List, Set
|
||||
from typing import Any, Dict, List, Set, Tuple
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
|
|
@ -17,6 +16,7 @@ from llama_stack.apis.inference import Inference
|
|||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.post_training import PostTraining
|
||||
from llama_stack.apis.providers import Providers as ProvidersAPI
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||
|
|
@ -35,6 +35,7 @@ from llama_stack.distribution.datatypes import (
|
|||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
Api,
|
||||
BenchmarksProtocolPrivate,
|
||||
|
|
@ -50,7 +51,7 @@ from llama_stack.providers.datatypes import (
|
|||
VectorDBsProtocolPrivate,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class InvalidProviderError(Exception):
|
||||
|
|
@ -59,6 +60,7 @@ class InvalidProviderError(Exception):
|
|||
|
||||
def api_protocol_map() -> Dict[Api, Any]:
|
||||
return {
|
||||
Api.providers: ProvidersAPI,
|
||||
Api.agents: Agents,
|
||||
Api.inference: Inference,
|
||||
Api.inspect: Inspect,
|
||||
|
|
@ -104,60 +106,43 @@ class ProviderWithSpec(Provider):
|
|||
ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
|
||||
|
||||
|
||||
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
|
||||
async def resolve_impls(
|
||||
run_config: StackRunConfig,
|
||||
provider_registry: ProviderRegistry,
|
||||
dist_registry: DistributionRegistry,
|
||||
) -> Dict[Api, Any]:
|
||||
"""
|
||||
Does two things:
|
||||
- flatmaps, sorts and resolves the providers in dependency order
|
||||
- for each API, produces either a (local, passthrough or router) implementation
|
||||
Resolves provider implementations by:
|
||||
1. Validating and organizing providers.
|
||||
2. Sorting them in dependency order.
|
||||
3. Instantiating them with required dependencies.
|
||||
"""
|
||||
routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
|
||||
router_apis = set(x.router_api for x in builtin_automatically_routed_apis())
|
||||
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
|
||||
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
|
||||
|
||||
providers_with_specs = {}
|
||||
|
||||
for api_str, providers in run_config.providers.items():
|
||||
api = Api(api_str)
|
||||
if api in routing_table_apis:
|
||||
raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")
|
||||
|
||||
specs = {}
|
||||
for provider in providers:
|
||||
if provider.provider_type not in provider_registry[api]:
|
||||
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
|
||||
|
||||
p = provider_registry[api][provider.provider_type]
|
||||
if p.deprecation_error:
|
||||
log.error(p.deprecation_error, "red", attrs=["bold"])
|
||||
raise InvalidProviderError(p.deprecation_error)
|
||||
|
||||
elif p.deprecation_warning:
|
||||
log.warning(
|
||||
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
|
||||
)
|
||||
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
|
||||
spec = ProviderWithSpec(
|
||||
spec=p,
|
||||
**(provider.model_dump()),
|
||||
)
|
||||
specs[provider.provider_id] = spec
|
||||
|
||||
key = api_str if api not in router_apis else f"inner-{api_str}"
|
||||
providers_with_specs[key] = specs
|
||||
providers_with_specs = validate_and_prepare_providers(
|
||||
run_config, provider_registry, routing_table_apis, router_apis
|
||||
)
|
||||
|
||||
apis_to_serve = run_config.apis or set(
|
||||
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
|
||||
)
|
||||
|
||||
providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
|
||||
|
||||
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
|
||||
|
||||
return await instantiate_providers(sorted_providers, router_apis, dist_registry)
|
||||
|
||||
|
||||
def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, Dict[str, ProviderWithSpec]]:
|
||||
"""Generates specifications for automatically routed APIs."""
|
||||
specs = {}
|
||||
for info in builtin_automatically_routed_apis():
|
||||
if info.router_api.value not in apis_to_serve:
|
||||
continue
|
||||
|
||||
providers_with_specs[info.routing_table_api.value] = {
|
||||
specs[info.routing_table_api.value] = {
|
||||
"__builtin__": ProviderWithSpec(
|
||||
provider_id="__routing_table__",
|
||||
provider_type="__routing_table__",
|
||||
|
|
@ -167,12 +152,12 @@ async def resolve_impls(
|
|||
router_api=info.router_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
api_dependencies=[],
|
||||
deps__=([f"inner-{info.router_api.value}"]),
|
||||
deps__=[f"inner-{info.router_api.value}"],
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
providers_with_specs[info.router_api.value] = {
|
||||
specs[info.router_api.value] = {
|
||||
"__builtin__": ProviderWithSpec(
|
||||
provider_id="__autorouted__",
|
||||
provider_type="__autorouted__",
|
||||
|
|
@ -182,12 +167,68 @@ async def resolve_impls(
|
|||
module="llama_stack.distribution.routers",
|
||||
routing_table_api=info.routing_table_api,
|
||||
api_dependencies=[info.routing_table_api],
|
||||
deps__=([info.routing_table_api.value]),
|
||||
# Add telemetry as an optional dependency to all auto-routed providers
|
||||
optional_api_dependencies=[Api.telemetry],
|
||||
deps__=([info.routing_table_api.value, Api.telemetry.value]),
|
||||
),
|
||||
)
|
||||
}
|
||||
return specs
|
||||
|
||||
sorted_providers = topological_sort({k: v.values() for k, v in providers_with_specs.items()})
|
||||
|
||||
def validate_and_prepare_providers(
|
||||
run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: Set[Api], router_apis: Set[Api]
|
||||
) -> Dict[str, Dict[str, ProviderWithSpec]]:
|
||||
"""Validates providers, handles deprecations, and organizes them into a spec dictionary."""
|
||||
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]] = {}
|
||||
|
||||
for api_str, providers in run_config.providers.items():
|
||||
api = Api(api_str)
|
||||
if api in routing_table_apis:
|
||||
raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")
|
||||
|
||||
specs = {}
|
||||
for provider in providers:
|
||||
if not provider.provider_id or provider.provider_id == "__disabled__":
|
||||
logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
|
||||
continue
|
||||
|
||||
validate_provider(provider, api, provider_registry)
|
||||
p = provider_registry[api][provider.provider_type]
|
||||
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
|
||||
spec = ProviderWithSpec(spec=p, **provider.model_dump())
|
||||
specs[provider.provider_id] = spec
|
||||
|
||||
key = api_str if api not in router_apis else f"inner-{api_str}"
|
||||
providers_with_specs[key] = specs
|
||||
|
||||
return providers_with_specs
|
||||
|
||||
|
||||
def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry):
|
||||
"""Validates if the provider is allowed and handles deprecations."""
|
||||
if provider.provider_type not in provider_registry[api]:
|
||||
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
|
||||
|
||||
p = provider_registry[api][provider.provider_type]
|
||||
if p.deprecation_error:
|
||||
logger.error(p.deprecation_error)
|
||||
raise InvalidProviderError(p.deprecation_error)
|
||||
elif p.deprecation_warning:
|
||||
logger.warning(
|
||||
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
|
||||
)
|
||||
|
||||
|
||||
def sort_providers_by_deps(
|
||||
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]], run_config: StackRunConfig
|
||||
) -> List[Tuple[str, ProviderWithSpec]]:
|
||||
"""Sorts providers based on their dependencies."""
|
||||
sorted_providers: List[Tuple[str, ProviderWithSpec]] = topological_sort(
|
||||
{k: list(v.values()) for k, v in providers_with_specs.items()}
|
||||
)
|
||||
|
||||
# Append built-in "inspect" provider
|
||||
apis = [x[1].spec.api for x in sorted_providers]
|
||||
sorted_providers.append(
|
||||
(
|
||||
|
|
@ -195,28 +236,51 @@ async def resolve_impls(
|
|||
ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
provider_type="__builtin__",
|
||||
config={
|
||||
"run_config": run_config.dict(),
|
||||
},
|
||||
config={"run_config": run_config.model_dump()},
|
||||
spec=InlineProviderSpec(
|
||||
api=Api.inspect,
|
||||
provider_type="__builtin__",
|
||||
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
|
||||
module="llama_stack.distribution.inspect",
|
||||
api_dependencies=apis,
|
||||
deps__=([x.value for x in apis]),
|
||||
deps__=[x.value for x in apis],
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
log.info(f"Resolved {len(sorted_providers)} providers")
|
||||
for api_str, provider in sorted_providers:
|
||||
log.info(f" {api_str} => {provider.provider_id}")
|
||||
log.info("")
|
||||
sorted_providers.append(
|
||||
(
|
||||
"providers",
|
||||
ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
provider_type="__builtin__",
|
||||
config={"run_config": run_config.model_dump()},
|
||||
spec=InlineProviderSpec(
|
||||
api=Api.providers,
|
||||
provider_type="__builtin__",
|
||||
config_class="llama_stack.distribution.providers.ProviderImplConfig",
|
||||
module="llama_stack.distribution.providers",
|
||||
api_dependencies=apis,
|
||||
deps__=[x.value for x in apis],
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
impls = {}
|
||||
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
|
||||
logger.debug(f"Resolved {len(sorted_providers)} providers")
|
||||
for api_str, provider in sorted_providers:
|
||||
logger.debug(f" {api_str} => {provider.provider_id}")
|
||||
logger.debug("")
|
||||
return sorted_providers
|
||||
|
||||
|
||||
async def instantiate_providers(
|
||||
sorted_providers: List[Tuple[str, ProviderWithSpec]], router_apis: Set[Api], dist_registry: DistributionRegistry
|
||||
) -> Dict:
|
||||
"""Instantiates providers asynchronously while managing dependencies."""
|
||||
impls: Dict[Api, Any] = {}
|
||||
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
|
||||
for api_str, provider in sorted_providers:
|
||||
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
||||
for a in provider.spec.optional_api_dependencies:
|
||||
|
|
@ -227,14 +291,9 @@ async def resolve_impls(
|
|||
if isinstance(provider.spec, RoutingTableProviderSpec):
|
||||
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
|
||||
|
||||
impl = await instantiate_provider(
|
||||
provider,
|
||||
deps,
|
||||
inner_impls,
|
||||
dist_registry,
|
||||
)
|
||||
# TODO: ugh slightly redesign this shady looking code
|
||||
if "inner-" in api_str:
|
||||
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
|
||||
|
||||
if api_str.startswith("inner-"):
|
||||
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
|
||||
else:
|
||||
api = Api(api_str)
|
||||
|
|
@ -245,7 +304,7 @@ async def resolve_impls(
|
|||
|
||||
def topological_sort(
|
||||
providers_with_specs: Dict[str, List[ProviderWithSpec]],
|
||||
) -> List[ProviderWithSpec]:
|
||||
) -> List[Tuple[str, ProviderWithSpec]]:
|
||||
def dfs(kv, visited: Set[str], stack: List[str]):
|
||||
api_str, providers = kv
|
||||
visited.add(api_str)
|
||||
|
|
@ -261,8 +320,8 @@ def topological_sort(
|
|||
|
||||
stack.append(api_str)
|
||||
|
||||
visited = set()
|
||||
stack = []
|
||||
visited: Set[str] = set()
|
||||
stack: List[str] = []
|
||||
|
||||
for api_str, providers in providers_with_specs.items():
|
||||
if api_str not in visited:
|
||||
|
|
@ -272,13 +331,14 @@ def topological_sort(
|
|||
for api_str in stack:
|
||||
for provider in providers_with_specs[api_str]:
|
||||
flattened.append((api_str, provider))
|
||||
|
||||
return flattened
|
||||
|
||||
|
||||
# returns a class implementing the protocol corresponding to the Api
|
||||
async def instantiate_provider(
|
||||
provider: ProviderWithSpec,
|
||||
deps: Dict[str, Any],
|
||||
deps: Dict[Api, Any],
|
||||
inner_impls: Dict[str, Any],
|
||||
dist_registry: DistributionRegistry,
|
||||
):
|
||||
|
|
@ -286,8 +346,10 @@ async def instantiate_provider(
|
|||
additional_protocols = additional_protocols_map()
|
||||
|
||||
provider_spec = provider.spec
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
if not hasattr(provider_spec, "module"):
|
||||
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
|
||||
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
args = []
|
||||
if isinstance(provider_spec, RemoteProviderSpec):
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
|
|
@ -350,7 +412,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
|||
obj_params = set(obj_sig.parameters)
|
||||
obj_params.discard("self")
|
||||
if not (proto_params <= obj_params):
|
||||
log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
|
||||
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
|
||||
missing_methods.append((name, "signature_mismatch"))
|
||||
else:
|
||||
# Check if the method is actually implemented in the class
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ async def get_routing_table_impl(
|
|||
return impl
|
||||
|
||||
|
||||
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
|
||||
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
|
||||
from .routers import (
|
||||
DatasetIORouter,
|
||||
EvalRouter,
|
||||
|
|
@ -65,9 +65,17 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
|
|||
"eval": EvalRouter,
|
||||
"tool_runtime": ToolRuntimeRouter,
|
||||
}
|
||||
api_to_deps = {
|
||||
"inference": {"telemetry": Api.telemetry},
|
||||
}
|
||||
if api.value not in api_to_routers:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
||||
impl = api_to_routers[api.value](routing_table)
|
||||
api_to_dep_impl = {}
|
||||
for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
|
||||
if dep_api in deps:
|
||||
api_to_dep_impl[dep_name] = deps[dep_api]
|
||||
|
||||
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -4,14 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
||||
from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
||||
from llama_stack.apis.eval import (
|
||||
BenchmarkConfig,
|
||||
Eval,
|
||||
|
|
@ -20,6 +22,10 @@ from llama_stack.apis.eval import (
|
|||
JobStatus,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
|
|
@ -27,13 +33,14 @@ from llama_stack.apis.inference import (
|
|||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||
from llama_stack.apis.scoring import (
|
||||
ScoreBatchResponse,
|
||||
|
|
@ -42,6 +49,7 @@ from llama_stack.apis.scoring import (
|
|||
ScoringFnParams,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
|
||||
from llama_stack.apis.tools import (
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
|
|
@ -51,8 +59,13 @@ from llama_stack.apis.tools import (
|
|||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
|
||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class VectorIORouter(VectorIO):
|
||||
|
|
@ -62,12 +75,15 @@ class VectorIORouter(VectorIO):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logger.debug("Initializing VectorIORouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("VectorIORouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("VectorIORouter.shutdown")
|
||||
pass
|
||||
|
||||
async def register_vector_db(
|
||||
|
|
@ -78,6 +94,7 @@ class VectorIORouter(VectorIO):
|
|||
provider_id: Optional[str] = None,
|
||||
provider_vector_db_id: Optional[str] = None,
|
||||
) -> None:
|
||||
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
|
||||
await self.routing_table.register_vector_db(
|
||||
vector_db_id,
|
||||
embedding_model,
|
||||
|
|
@ -92,6 +109,9 @@ class VectorIORouter(VectorIO):
|
|||
chunks: List[Chunk],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
|
||||
|
||||
async def query_chunks(
|
||||
|
|
@ -100,6 +120,7 @@ class VectorIORouter(VectorIO):
|
|||
query: InterleavedContent,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryChunksResponse:
|
||||
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
|
||||
|
||||
|
||||
|
|
@ -109,13 +130,21 @@ class InferenceRouter(Inference):
|
|||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
telemetry: Optional[Telemetry] = None,
|
||||
) -> None:
|
||||
logger.debug("Initializing InferenceRouter")
|
||||
self.routing_table = routing_table
|
||||
self.telemetry = telemetry
|
||||
if self.telemetry:
|
||||
self.tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(self.tokenizer)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("InferenceRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("InferenceRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def register_model(
|
||||
|
|
@ -126,13 +155,85 @@ class InferenceRouter(Inference):
|
|||
metadata: Optional[Dict[str, Any]] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
|
||||
)
|
||||
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
||||
|
||||
def _construct_metrics(
|
||||
self,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
model: Model,
|
||||
) -> List[MetricEvent]:
|
||||
"""Constructs a list of MetricEvent objects containing token usage metrics.
|
||||
|
||||
Args:
|
||||
prompt_tokens: Number of tokens in the prompt
|
||||
completion_tokens: Number of tokens in the completion
|
||||
total_tokens: Total number of tokens used
|
||||
model: Model object containing model_id and provider_id
|
||||
|
||||
Returns:
|
||||
List of MetricEvent objects with token usage metrics
|
||||
"""
|
||||
span = get_current_span()
|
||||
if span is None:
|
||||
logger.warning("No span found for token usage metrics")
|
||||
return []
|
||||
metrics = [
|
||||
("prompt_tokens", prompt_tokens),
|
||||
("completion_tokens", completion_tokens),
|
||||
("total_tokens", total_tokens),
|
||||
]
|
||||
metric_events = []
|
||||
for metric_name, value in metrics:
|
||||
metric_events.append(
|
||||
MetricEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
metric=metric_name,
|
||||
value=value,
|
||||
timestamp=time.time(),
|
||||
unit="tokens",
|
||||
attributes={
|
||||
"model_id": model.model_id,
|
||||
"provider_id": model.provider_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
return metric_events
|
||||
|
||||
async def _compute_and_log_token_usage(
|
||||
self,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
model: Model,
|
||||
) -> List[MetricInResponse]:
|
||||
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
|
||||
if self.telemetry:
|
||||
for metric in metrics:
|
||||
await self.telemetry.log_event(metric)
|
||||
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
|
||||
|
||||
async def _count_tokens(
|
||||
self,
|
||||
messages: List[Message] | InterleavedContent,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
) -> Optional[int]:
|
||||
if isinstance(messages, list):
|
||||
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
|
||||
else:
|
||||
encoded = self.formatter.encode_content(messages)
|
||||
return len(encoded.tokens) if encoded and encoded.tokens else 0
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = None,
|
||||
|
|
@ -140,7 +241,12 @@ class InferenceRouter(Inference):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
logger.debug(
|
||||
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
||||
)
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
|
|
@ -159,8 +265,6 @@ class InferenceRouter(Inference):
|
|||
params["tool_prompt_format"] = tool_prompt_format
|
||||
tool_config = ToolConfig(**params)
|
||||
|
||||
tool_config.tool_prompt_format = tool_config.tool_prompt_format or get_default_tool_prompt_format(model_id)
|
||||
|
||||
tools = tools or []
|
||||
if tool_config.tool_choice == ToolChoice.none:
|
||||
tools = []
|
||||
|
|
@ -187,20 +291,67 @@ class InferenceRouter(Inference):
|
|||
tool_config=tool_config,
|
||||
)
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
|
||||
|
||||
if stream:
|
||||
return (chunk async for chunk in await provider.chat_completion(**params))
|
||||
|
||||
async def stream_generator():
|
||||
completion_text = ""
|
||||
async for chunk in await provider.chat_completion(**params):
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
|
||||
if chunk.event.delta.type == "text":
|
||||
completion_text += chunk.event.delta.text
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
|
||||
completion_tokens = await self._count_tokens(
|
||||
[
|
||||
CompletionMessage(
|
||||
content=completion_text,
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
],
|
||||
tool_config.tool_prompt_format,
|
||||
)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
||||
yield chunk
|
||||
|
||||
return stream_generator()
|
||||
else:
|
||||
return await provider.chat_completion(**params)
|
||||
response = await provider.chat_completion(**params)
|
||||
completion_tokens = await self._count_tokens(
|
||||
[response.completion_message],
|
||||
tool_config.tool_prompt_format,
|
||||
)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||
return response
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
logger.debug(
|
||||
f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
|
||||
)
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
|
|
@ -215,10 +366,41 @@ class InferenceRouter(Inference):
|
|||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
prompt_tokens = await self._count_tokens(content)
|
||||
|
||||
if stream:
|
||||
return (chunk async for chunk in await provider.completion(**params))
|
||||
|
||||
async def stream_generator():
|
||||
completion_text = ""
|
||||
async for chunk in await provider.completion(**params):
|
||||
if hasattr(chunk, "delta"):
|
||||
completion_text += chunk.delta
|
||||
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
|
||||
completion_tokens = await self._count_tokens(completion_text)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
||||
yield chunk
|
||||
|
||||
return stream_generator()
|
||||
else:
|
||||
return await provider.completion(**params)
|
||||
response = await provider.completion(**params)
|
||||
completion_tokens = await self._count_tokens(response.content)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||
return response
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
|
|
@ -228,6 +410,7 @@ class InferenceRouter(Inference):
|
|||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
logger.debug(f"InferenceRouter.embeddings: {model_id}")
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
|
|
@ -247,12 +430,15 @@ class SafetyRouter(Safety):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logger.debug("Initializing SafetyRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("SafetyRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("SafetyRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def register_shield(
|
||||
|
|
@ -262,6 +448,7 @@ class SafetyRouter(Safety):
|
|||
provider_id: Optional[str] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> Shield:
|
||||
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
|
||||
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
||||
|
||||
async def run_shield(
|
||||
|
|
@ -270,6 +457,7 @@ class SafetyRouter(Safety):
|
|||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
||||
return await self.routing_table.get_provider_impl(shield_id).run_shield(
|
||||
shield_id=shield_id,
|
||||
messages=messages,
|
||||
|
|
@ -282,29 +470,51 @@ class DatasetIORouter(DatasetIO):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logger.debug("Initializing DatasetIORouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("DatasetIORouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("DatasetIORouter.shutdown")
|
||||
pass
|
||||
|
||||
async def get_rows_paginated(
|
||||
async def register_dataset(
|
||||
self,
|
||||
purpose: DatasetPurpose,
|
||||
source: DataSource,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
dataset_id: Optional[str] = None,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
|
||||
)
|
||||
await self.routing_table.register_dataset(
|
||||
purpose=purpose,
|
||||
source=source,
|
||||
metadata=metadata,
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
|
||||
async def iterrows(
|
||||
self,
|
||||
dataset_id: str,
|
||||
rows_in_page: int,
|
||||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult:
|
||||
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
|
||||
start_index: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> IterrowsResponse:
|
||||
logger.debug(
|
||||
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(dataset_id).iterrows(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=rows_in_page,
|
||||
page_token=page_token,
|
||||
filter_condition=filter_condition,
|
||||
start_index=start_index,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
|
||||
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
|
||||
dataset_id=dataset_id,
|
||||
rows=rows,
|
||||
|
|
@ -316,12 +526,15 @@ class ScoringRouter(Scoring):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logger.debug("Initializing ScoringRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("ScoringRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("ScoringRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def score_batch(
|
||||
|
|
@ -330,6 +543,7 @@ class ScoringRouter(Scoring):
|
|||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
|
||||
res = {}
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
|
||||
|
|
@ -350,6 +564,7 @@ class ScoringRouter(Scoring):
|
|||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
) -> ScoreResponse:
|
||||
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
|
||||
res = {}
|
||||
# look up and map each scoring function to its provider impl
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
|
|
@ -367,22 +582,26 @@ class EvalRouter(Eval):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logger.debug("Initializing EvalRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("EvalRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("EvalRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def run_eval(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
task_config: BenchmarkConfig,
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
|
||||
benchmark_id=benchmark_id,
|
||||
task_config=task_config,
|
||||
benchmark_config=benchmark_config,
|
||||
)
|
||||
|
||||
async def evaluate_rows(
|
||||
|
|
@ -390,13 +609,14 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
task_config: BenchmarkConfig,
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
|
||||
benchmark_id=benchmark_id,
|
||||
input_rows=input_rows,
|
||||
scoring_functions=scoring_functions,
|
||||
task_config=task_config,
|
||||
benchmark_config=benchmark_config,
|
||||
)
|
||||
|
||||
async def job_status(
|
||||
|
|
@ -404,6 +624,7 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
job_id: str,
|
||||
) -> Optional[JobStatus]:
|
||||
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
|
||||
|
||||
async def job_cancel(
|
||||
|
|
@ -411,6 +632,7 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
job_id: str,
|
||||
) -> None:
|
||||
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
|
||||
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
|
||||
benchmark_id,
|
||||
job_id,
|
||||
|
|
@ -421,6 +643,7 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
job_id: str,
|
||||
) -> EvaluateResponse:
|
||||
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
|
||||
benchmark_id,
|
||||
job_id,
|
||||
|
|
@ -433,6 +656,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def query(
|
||||
|
|
@ -441,7 +665,8 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
vector_db_ids: List[str],
|
||||
query_config: Optional[RAGQueryConfig] = None,
|
||||
) -> RAGQueryResult:
|
||||
return await self.routing_table.get_provider_impl("query_from_memory").query(
|
||||
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
|
||||
return await self.routing_table.get_provider_impl("knowledge_search").query(
|
||||
content, vector_db_ids, query_config
|
||||
)
|
||||
|
||||
|
|
@ -451,6 +676,9 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
||||
)
|
||||
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
|
||||
documents, vector_db_id, chunk_size_in_tokens
|
||||
)
|
||||
|
|
@ -459,6 +687,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logger.debug("Initializing ToolRuntimeRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
|
||||
|
|
@ -467,12 +696,15 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("ToolRuntimeRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("ToolRuntimeRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
|
||||
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
|
||||
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
|
||||
tool_name=tool_name,
|
||||
kwargs=kwargs,
|
||||
|
|
@ -481,4 +713,5 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
|
@ -12,7 +13,16 @@ from pydantic import TypeAdapter
|
|||
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
|
||||
from llama_stack.apis.datasets import (
|
||||
Dataset,
|
||||
DatasetPurpose,
|
||||
Datasets,
|
||||
DatasetType,
|
||||
DataSource,
|
||||
ListDatasetsResponse,
|
||||
RowsDataSource,
|
||||
URIDataSource,
|
||||
)
|
||||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
|
|
@ -31,11 +41,22 @@ from llama_stack.apis.tools import (
|
|||
ToolHost,
|
||||
)
|
||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||
from llama_stack.distribution.access_control import check_access
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AccessAttributes,
|
||||
BenchmarkWithACL,
|
||||
DatasetWithACL,
|
||||
ModelWithACL,
|
||||
RoutableObject,
|
||||
RoutableObjectWithProvider,
|
||||
RoutedProtocol,
|
||||
ScoringFnWithACL,
|
||||
ShieldWithACL,
|
||||
ToolGroupWithACL,
|
||||
ToolWithACL,
|
||||
VectorDBWithACL,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import get_auth_attributes
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
|
||||
|
|
@ -176,6 +197,11 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
if not obj:
|
||||
return None
|
||||
|
||||
# Check if user has permission to access this object
|
||||
if not check_access(obj, get_auth_attributes()):
|
||||
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
|
||||
return None
|
||||
|
||||
return obj
|
||||
|
||||
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
|
||||
|
|
@ -192,6 +218,13 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
|
||||
# If object supports access control but no attributes set, use creator's attributes
|
||||
if not obj.access_attributes:
|
||||
creator_attributes = get_auth_attributes()
|
||||
if creator_attributes:
|
||||
obj.access_attributes = AccessAttributes(**creator_attributes)
|
||||
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
|
||||
|
||||
registered_obj = await register_object_with_provider(obj, p)
|
||||
# TODO: This needs to be fixed for all APIs once they return the registered object
|
||||
if obj.type == ResourceType.model.value:
|
||||
|
|
@ -204,15 +237,24 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
|
||||
objs = await self.dist_registry.get_all()
|
||||
return [obj for obj in objs if obj.type == type]
|
||||
filtered_objs = [obj for obj in objs if obj.type == type]
|
||||
|
||||
# Apply attribute-based access control filtering
|
||||
if filtered_objs:
|
||||
filtered_objs = [obj for obj in filtered_objs if check_access(obj, get_auth_attributes())]
|
||||
|
||||
return filtered_objs
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
async def list_models(self) -> ListModelsResponse:
|
||||
return ListModelsResponse(data=await self.get_all_with_type("model"))
|
||||
|
||||
async def get_model(self, model_id: str) -> Optional[Model]:
|
||||
return await self.get_object_by_identifier("model", model_id)
|
||||
async def get_model(self, model_id: str) -> Model:
|
||||
model = await self.get_object_by_identifier("model", model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
return model
|
||||
|
||||
async def register_model(
|
||||
self,
|
||||
|
|
@ -238,7 +280,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
model_type = ModelType.llm
|
||||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||
model = Model(
|
||||
model = ModelWithACL(
|
||||
identifier=model_id,
|
||||
provider_resource_id=provider_model_id,
|
||||
provider_id=provider_id,
|
||||
|
|
@ -259,8 +301,11 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
async def list_shields(self) -> ListShieldsResponse:
|
||||
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
|
||||
|
||||
async def get_shield(self, identifier: str) -> Optional[Shield]:
|
||||
return await self.get_object_by_identifier("shield", identifier)
|
||||
async def get_shield(self, identifier: str) -> Shield:
|
||||
shield = await self.get_object_by_identifier("shield", identifier)
|
||||
if shield is None:
|
||||
raise ValueError(f"Shield '{identifier}' not found")
|
||||
return shield
|
||||
|
||||
async def register_shield(
|
||||
self,
|
||||
|
|
@ -281,7 +326,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
)
|
||||
if params is None:
|
||||
params = {}
|
||||
shield = Shield(
|
||||
shield = ShieldWithACL(
|
||||
identifier=shield_id,
|
||||
provider_resource_id=provider_shield_id,
|
||||
provider_id=provider_id,
|
||||
|
|
@ -295,8 +340,11 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
async def list_vector_dbs(self) -> ListVectorDBsResponse:
|
||||
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
|
||||
|
||||
async def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]:
|
||||
return await self.get_object_by_identifier("vector_db", vector_db_id)
|
||||
async def get_vector_db(self, vector_db_id: str) -> VectorDB:
|
||||
vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
|
||||
if vector_db is None:
|
||||
raise ValueError(f"Vector DB '{vector_db_id}' not found")
|
||||
return vector_db
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
|
|
@ -309,23 +357,17 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
if provider_vector_db_id is None:
|
||||
provider_vector_db_id = vector_db_id
|
||||
if provider_id is None:
|
||||
# If provider_id not specified, use the only provider if it supports this shield type
|
||||
if len(self.impls_by_provider_id) == 1:
|
||||
if len(self.impls_by_provider_id) > 0:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
if len(self.impls_by_provider_id) > 1:
|
||||
logger.warning(
|
||||
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
)
|
||||
raise ValueError("No provider available. Please configure a vector_io provider.")
|
||||
model = await self.get_object_by_identifier("model", embedding_model)
|
||||
if model is None:
|
||||
if embedding_model == "all-MiniLM-L6-v2":
|
||||
raise ValueError(
|
||||
"Embeddings are now served via Inference providers. "
|
||||
"Please upgrade your run.yaml to include inline::sentence-transformer as an additional inference provider. "
|
||||
"See https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/together/run.yaml for an example."
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Model {embedding_model} not found")
|
||||
raise ValueError(f"Model {embedding_model} not found")
|
||||
if model.model_type != ModelType.embedding:
|
||||
raise ValueError(f"Model {embedding_model} is not an embedding model")
|
||||
if "embedding_dimension" not in model.metadata:
|
||||
|
|
@ -338,7 +380,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
"embedding_model": embedding_model,
|
||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||
}
|
||||
vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data)
|
||||
vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
|
||||
await self.register_object(vector_db)
|
||||
return vector_db
|
||||
|
||||
|
|
@ -353,39 +395,56 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
async def list_datasets(self) -> ListDatasetsResponse:
|
||||
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
|
||||
|
||||
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
|
||||
return await self.get_object_by_identifier("dataset", dataset_id)
|
||||
async def get_dataset(self, dataset_id: str) -> Dataset:
|
||||
dataset = await self.get_object_by_identifier("dataset", dataset_id)
|
||||
if dataset is None:
|
||||
raise ValueError(f"Dataset '{dataset_id}' not found")
|
||||
return dataset
|
||||
|
||||
async def register_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
dataset_schema: Dict[str, ParamType],
|
||||
url: URL,
|
||||
provider_dataset_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
purpose: DatasetPurpose,
|
||||
source: DataSource,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
if provider_dataset_id is None:
|
||||
provider_dataset_id = dataset_id
|
||||
if provider_id is None:
|
||||
# If provider_id not specified, use the only provider if it supports this dataset
|
||||
if len(self.impls_by_provider_id) == 1:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
dataset_id: Optional[str] = None,
|
||||
) -> Dataset:
|
||||
if isinstance(source, dict):
|
||||
if source["type"] == "uri":
|
||||
source = URIDataSource.parse_obj(source)
|
||||
elif source["type"] == "rows":
|
||||
source = RowsDataSource.parse_obj(source)
|
||||
|
||||
if not dataset_id:
|
||||
dataset_id = f"dataset-{str(uuid.uuid4())}"
|
||||
|
||||
provider_dataset_id = dataset_id
|
||||
|
||||
# infer provider from source
|
||||
if source.type == DatasetType.rows.value:
|
||||
provider_id = "localfs"
|
||||
elif source.type == DatasetType.uri.value:
|
||||
# infer provider from uri
|
||||
if source.uri.startswith("huggingface"):
|
||||
provider_id = "huggingface"
|
||||
else:
|
||||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
)
|
||||
provider_id = "localfs"
|
||||
else:
|
||||
raise ValueError(f"Unknown data source type: {source.type}")
|
||||
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
dataset = Dataset(
|
||||
|
||||
dataset = DatasetWithACL(
|
||||
identifier=dataset_id,
|
||||
provider_resource_id=provider_dataset_id,
|
||||
provider_id=provider_id,
|
||||
dataset_schema=dataset_schema,
|
||||
url=url,
|
||||
purpose=purpose,
|
||||
source=source,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
await self.register_object(dataset)
|
||||
return dataset
|
||||
|
||||
async def unregister_dataset(self, dataset_id: str) -> None:
|
||||
dataset = await self.get_dataset(dataset_id)
|
||||
|
|
@ -398,8 +457,11 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
|||
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
|
||||
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
|
||||
|
||||
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
|
||||
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
|
||||
async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
|
||||
scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
|
||||
if scoring_fn is None:
|
||||
raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
|
||||
return scoring_fn
|
||||
|
||||
async def register_scoring_function(
|
||||
self,
|
||||
|
|
@ -419,7 +481,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
|||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
)
|
||||
scoring_fn = ScoringFn(
|
||||
scoring_fn = ScoringFnWithACL(
|
||||
identifier=scoring_fn_id,
|
||||
description=description,
|
||||
return_type=return_type,
|
||||
|
|
@ -435,8 +497,11 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
|||
async def list_benchmarks(self) -> ListBenchmarksResponse:
|
||||
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
|
||||
|
||||
async def get_benchmark(self, benchmark_id: str) -> Optional[Benchmark]:
|
||||
return await self.get_object_by_identifier("benchmark", benchmark_id)
|
||||
async def get_benchmark(self, benchmark_id: str) -> Benchmark:
|
||||
benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
|
||||
if benchmark is None:
|
||||
raise ValueError(f"Benchmark '{benchmark_id}' not found")
|
||||
return benchmark
|
||||
|
||||
async def register_benchmark(
|
||||
self,
|
||||
|
|
@ -458,7 +523,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
|||
)
|
||||
if provider_benchmark_id is None:
|
||||
provider_benchmark_id = benchmark_id
|
||||
benchmark = Benchmark(
|
||||
benchmark = BenchmarkWithACL(
|
||||
identifier=benchmark_id,
|
||||
dataset_id=dataset_id,
|
||||
scoring_functions=scoring_functions,
|
||||
|
|
@ -480,7 +545,10 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
|
||||
|
||||
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
|
||||
return await self.get_object_by_identifier("tool_group", toolgroup_id)
|
||||
tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
|
||||
if tool_group is None:
|
||||
raise ValueError(f"Tool group '{toolgroup_id}' not found")
|
||||
return tool_group
|
||||
|
||||
async def get_tool(self, tool_name: str) -> Tool:
|
||||
return await self.get_object_by_identifier("tool", tool_name)
|
||||
|
|
@ -498,7 +566,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
|
||||
for tool_def in tool_defs:
|
||||
tools.append(
|
||||
Tool(
|
||||
ToolWithACL(
|
||||
identifier=tool_def.name,
|
||||
toolgroup_id=toolgroup_id,
|
||||
description=tool_def.description or "",
|
||||
|
|
@ -523,7 +591,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
await self.register_object(tool)
|
||||
|
||||
await self.dist_registry.register(
|
||||
ToolGroup(
|
||||
ToolGroupWithACL(
|
||||
identifier=toolgroup_id,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=toolgroup_id,
|
||||
|
|
@ -536,7 +604,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
tool_group = await self.get_tool_group(toolgroup_id)
|
||||
if tool_group is None:
|
||||
raise ValueError(f"Tool group {toolgroup_id} not found")
|
||||
tools = await self.list_tools(toolgroup_id).data
|
||||
tools = (await self.list_tools(toolgroup_id)).data
|
||||
for tool in tools:
|
||||
await self.unregister_object(tool)
|
||||
await self.unregister_object(tool_group)
|
||||
|
|
|
|||
203
llama_stack/distribution/server/auth.py
Normal file
203
llama_stack/distribution/server/auth.py
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
|
||||
|
||||
class AuthRequestContext(BaseModel):
|
||||
path: str = Field(description="The path of the request being authenticated")
|
||||
|
||||
headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
|
||||
|
||||
params: Dict[str, List[str]] = Field(
|
||||
description="Query parameters from the original request, parsed as dictionary of lists"
|
||||
)
|
||||
|
||||
|
||||
class AuthRequest(BaseModel):
|
||||
api_key: str = Field(description="The API key extracted from the Authorization header")
|
||||
|
||||
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
"""The format of the authentication response from the auth endpoint."""
|
||||
|
||||
access_attributes: Optional[AccessAttributes] = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Structured user attributes for attribute-based access control.
|
||||
|
||||
These attributes determine which resources the user can access.
|
||||
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
|
||||
Each attribute category contains a list of values that the user has for that category.
|
||||
During access control checks, these values are compared against resource requirements.
|
||||
|
||||
Example with standard categories:
|
||||
```json
|
||||
{
|
||||
"roles": ["admin", "data-scientist"],
|
||||
"teams": ["ml-team"],
|
||||
"projects": ["llama-3"],
|
||||
"namespaces": ["research"]
|
||||
}
|
||||
```
|
||||
""",
|
||||
)
|
||||
|
||||
message: Optional[str] = Field(
|
||||
default=None, description="Optional message providing additional context about the authentication result."
|
||||
)
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
"""Middleware that authenticates requests using an external auth endpoint.
|
||||
|
||||
This middleware:
|
||||
1. Extracts the Bearer token from the Authorization header
|
||||
2. Sends it to the configured auth endpoint along with request details
|
||||
3. Validates the response and extracts user attributes
|
||||
4. Makes these attributes available to the route handlers for access control
|
||||
|
||||
Authentication Request Format:
|
||||
```json
|
||||
{
|
||||
"api_key": "the-api-key-extracted-from-auth-header",
|
||||
"request": {
|
||||
"path": "/models/list",
|
||||
"headers": {
|
||||
"content-type": "application/json",
|
||||
"user-agent": "..."
|
||||
// All headers except Authorization
|
||||
},
|
||||
"params": {
|
||||
"limit": ["100"],
|
||||
"offset": ["0"]
|
||||
// Query parameters as key -> list of values
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Expected Auth Endpoint Response Format:
|
||||
```json
|
||||
{
|
||||
"access_attributes": { // Structured attribute format
|
||||
"roles": ["admin", "user"],
|
||||
"teams": ["ml-team", "nlp-team"],
|
||||
"projects": ["llama-3", "project-x"],
|
||||
"namespaces": ["research"]
|
||||
},
|
||||
"message": "Optional message about auth result"
|
||||
}
|
||||
```
|
||||
|
||||
Attribute-Based Access Control:
|
||||
The attributes returned by the auth endpoint are used to determine which
|
||||
resources the user can access. Resources can specify required attributes
|
||||
using the access_attributes field. For a user to access a resource:
|
||||
|
||||
1. All attribute categories specified in the resource must be present in the user's attributes
|
||||
2. For each category, the user must have at least one matching value
|
||||
|
||||
If the auth endpoint doesn't return any attributes, the user will only be able to
|
||||
access resources that don't have access_attributes defined.
|
||||
"""
|
||||
|
||||
def __init__(self, app, auth_endpoint):
|
||||
self.app = app
|
||||
self.auth_endpoint = auth_endpoint
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
headers = dict(scope.get("headers", []))
|
||||
auth_header = headers.get(b"authorization", b"").decode()
|
||||
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
return await self._send_auth_error(send, "Missing or invalid Authorization header")
|
||||
|
||||
api_key = auth_header.split("Bearer ", 1)[1]
|
||||
|
||||
path = scope.get("path", "")
|
||||
request_headers = {k.decode(): v.decode() for k, v in headers.items()}
|
||||
|
||||
# Remove sensitive headers
|
||||
if "authorization" in request_headers:
|
||||
del request_headers["authorization"]
|
||||
|
||||
query_string = scope.get("query_string", b"").decode()
|
||||
params = parse_qs(query_string)
|
||||
|
||||
# Build the auth request model
|
||||
auth_request = AuthRequest(
|
||||
api_key=api_key,
|
||||
request=AuthRequestContext(
|
||||
path=path,
|
||||
headers=request_headers,
|
||||
params=params,
|
||||
),
|
||||
)
|
||||
|
||||
# Validate with authentication endpoint
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.auth_endpoint,
|
||||
json=auth_request.model_dump(),
|
||||
timeout=10.0, # Add a reasonable timeout
|
||||
)
|
||||
if response.status_code != 200:
|
||||
logger.warning(f"Authentication failed: {response.status_code}")
|
||||
return await self._send_auth_error(send, "Authentication failed")
|
||||
|
||||
# Parse and validate the auth response
|
||||
try:
|
||||
response_data = response.json()
|
||||
auth_response = AuthResponse(**response_data)
|
||||
|
||||
# Store attributes in request scope for access control
|
||||
if auth_response.access_attributes:
|
||||
user_attributes = auth_response.access_attributes.model_dump(exclude_none=True)
|
||||
else:
|
||||
logger.warning("No access attributes, setting namespace to api_key by default")
|
||||
user_attributes = {
|
||||
"namespaces": [api_key],
|
||||
}
|
||||
|
||||
scope["user_attributes"] = user_attributes
|
||||
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
|
||||
except Exception:
|
||||
logger.exception("Error parsing authentication response")
|
||||
return await self._send_auth_error(send, "Invalid authentication response format")
|
||||
except httpx.TimeoutException:
|
||||
logger.exception("Authentication request timed out")
|
||||
return await self._send_auth_error(send, "Authentication service timeout")
|
||||
except Exception:
|
||||
logger.exception("Error during authentication")
|
||||
return await self._send_auth_error(send, "Authentication service error")
|
||||
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
async def _send_auth_error(self, send, message):
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 401,
|
||||
"headers": [[b"content-type", b"application/json"]],
|
||||
}
|
||||
)
|
||||
error_msg = json.dumps({"error": {"message": message}}).encode()
|
||||
await send({"type": "http.response.body", "body": error_msg})
|
||||
|
|
@ -6,12 +6,9 @@
|
|||
|
||||
import argparse
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
|
|
@ -26,12 +23,14 @@ from fastapi import Path as FastapiPath
|
|||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
|
|
@ -39,23 +38,26 @@ from llama_stack.distribution.stack import (
|
|||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
|
||||
TelemetryAdapter,
|
||||
)
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
CURRENT_TRACE_CONTEXT,
|
||||
end_trace,
|
||||
setup_logger,
|
||||
start_trace,
|
||||
)
|
||||
|
||||
from .auth import AuthenticationMiddleware
|
||||
from .endpoints import get_all_api_endpoints
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="server")
|
||||
|
||||
|
||||
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
|
||||
|
|
@ -117,78 +119,32 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
|
|||
)
|
||||
|
||||
|
||||
def handle_signal(app, signum, _) -> None:
|
||||
async def shutdown(app):
|
||||
"""Initiate a graceful shutdown of the application.
|
||||
|
||||
Handled by the lifespan context manager. The shutdown process involves
|
||||
shutting down all implementations registered in the application.
|
||||
"""
|
||||
Handle incoming signals and initiate a graceful shutdown of the application.
|
||||
|
||||
This function is intended to be used as a signal handler for various signals
|
||||
(e.g., SIGINT, SIGTERM). Upon receiving a signal, it will print a message
|
||||
indicating the received signal and initiate a shutdown process.
|
||||
|
||||
Args:
|
||||
app: The application instance containing implementations to be shut down.
|
||||
signum (int): The signal number received.
|
||||
frame: The current stack frame (not used in this function).
|
||||
|
||||
The shutdown process involves:
|
||||
- Shutting down all implementations registered in the application.
|
||||
- Gathering all running asyncio tasks.
|
||||
- Cancelling all gathered tasks.
|
||||
- Waiting for all tasks to finish.
|
||||
- Stopping the event loop.
|
||||
|
||||
Note:
|
||||
This function schedules the shutdown process as an asyncio task and does
|
||||
not block the current execution.
|
||||
"""
|
||||
signame = signal.Signals(signum).name
|
||||
print(f"Received signal {signame} ({signum}). Exiting gracefully...")
|
||||
|
||||
async def shutdown():
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
impl_name = impl.__class__.__name__
|
||||
logger.info("Shutting down %s", impl_name)
|
||||
try:
|
||||
# Gracefully shut down implementations
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
impl_name = impl.__class__.__name__
|
||||
logger.info("Shutting down %s", impl_name)
|
||||
try:
|
||||
if hasattr(impl, "shutdown"):
|
||||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||
else:
|
||||
logger.warning("No shutdown method for %s", impl_name)
|
||||
except asyncio.TimeoutError:
|
||||
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
|
||||
|
||||
# Gather all running tasks
|
||||
loop = asyncio.get_running_loop()
|
||||
tasks = [task for task in asyncio.all_tasks(loop) if task is not asyncio.current_task()]
|
||||
|
||||
# Cancel all tasks
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
# Wait for all tasks to finish
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
|
||||
except asyncio.TimeoutError:
|
||||
logger.exception("Timeout while waiting for tasks to finish")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
loop.stop()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(shutdown())
|
||||
if hasattr(impl, "shutdown"):
|
||||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||
else:
|
||||
logger.warning("No shutdown method for %s", impl_name)
|
||||
except asyncio.TimeoutError:
|
||||
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
|
||||
except (Exception, asyncio.CancelledError) as e:
|
||||
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
print("Starting up")
|
||||
logger.info("Starting up")
|
||||
yield
|
||||
print("Shutting down")
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
await impl.shutdown()
|
||||
logger.info("Shutting down")
|
||||
await shutdown(app)
|
||||
|
||||
|
||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||
|
|
@ -204,15 +160,14 @@ async def maybe_await(value):
|
|||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
event_gen = await event_gen
|
||||
async for item in event_gen:
|
||||
async for item in await event_gen:
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
print("Generator cancelled")
|
||||
logger.info("Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
logger.exception("Error in sse_generator")
|
||||
yield create_sse_event(
|
||||
{
|
||||
"error": {
|
||||
|
|
@ -224,18 +179,25 @@ async def sse_generator(event_gen):
|
|||
|
||||
def create_dynamic_typed_route(func: Any, method: str, route: str):
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
set_request_provider_data(request.headers)
|
||||
# Get auth attributes from the request scope
|
||||
user_attributes = request.scope.get("user_attributes", {})
|
||||
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
try:
|
||||
if is_streaming:
|
||||
return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
|
||||
else:
|
||||
value = func(**kwargs)
|
||||
return await maybe_await(value)
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
raise translate_exception(e) from e
|
||||
# Use context manager with both provider data and auth attributes
|
||||
with request_provider_data_context(request.headers, user_attributes):
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
|
||||
try:
|
||||
if is_streaming:
|
||||
gen = preserve_contexts_async_generator(
|
||||
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
|
||||
)
|
||||
return StreamingResponse(gen, media_type="text/event-stream")
|
||||
else:
|
||||
value = func(**kwargs)
|
||||
return await maybe_await(value)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error executing endpoint {route=} {method=}")
|
||||
raise translate_exception(e) from e
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
||||
|
|
@ -264,7 +226,7 @@ class TracingMiddleware:
|
|||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
path = scope["path"]
|
||||
path = scope.get("path", "")
|
||||
await start_trace(path, {"__location__": "server"})
|
||||
try:
|
||||
return await self.app(scope, receive, send)
|
||||
|
|
@ -348,52 +310,66 @@ def main():
|
|||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.env:
|
||||
for env_pair in args.env:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
print(f"Setting CLI environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
print(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
log_line = ""
|
||||
if args.yaml_config:
|
||||
# if the user provided a config file, use it, even if template was specified
|
||||
config_file = Path(args.yaml_config)
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Config file {config_file} does not exist")
|
||||
print(f"Using config file: {config_file}")
|
||||
log_line = f"Using config file: {config_file}"
|
||||
elif args.template:
|
||||
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Template {args.template} does not exist")
|
||||
print(f"Using template {args.template} config file: {config_file}")
|
||||
log_line = f"Using template {args.template} config file: {config_file}"
|
||||
else:
|
||||
raise ValueError("Either --yaml-config or --template must be provided")
|
||||
|
||||
logger_config = None
|
||||
with open(config_file, "r") as fp:
|
||||
config = replace_env_vars(yaml.safe_load(fp))
|
||||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
logger = get_logger(name=__name__, category="server", config=logger_config)
|
||||
if args.env:
|
||||
for env_pair in args.env:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
logger.info(f"Setting CLI environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
config = replace_env_vars(config_contents)
|
||||
config = StackRunConfig(**config)
|
||||
|
||||
print("Run configuration:")
|
||||
# now that the logger is initialized, print the line about which type of config we are using.
|
||||
logger.info(log_line)
|
||||
|
||||
logger.info("Run configuration:")
|
||||
safe_config = redact_sensitive_fields(config.model_dump())
|
||||
print(yaml.dump(safe_config, indent=2))
|
||||
logger.info(yaml.dump(safe_config, indent=2))
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(TracingMiddleware)
|
||||
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
||||
app.add_middleware(ClientVersionMiddleware)
|
||||
|
||||
# Add authentication middleware if configured
|
||||
if config.server.auth and config.server.auth.endpoint:
|
||||
logger.info(f"Enabling authentication with endpoint: {config.server.auth.endpoint}")
|
||||
app.add_middleware(AuthenticationMiddleware, auth_endpoint=config.server.auth.endpoint)
|
||||
|
||||
try:
|
||||
impls = asyncio.run(construct_stack(config))
|
||||
except InvalidProviderError:
|
||||
except InvalidProviderError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
else:
|
||||
setup_logger(TelemetryAdapter(TelemetryConfig()))
|
||||
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
|
||||
|
||||
all_endpoints = get_all_api_endpoints()
|
||||
|
||||
|
|
@ -409,6 +385,7 @@ def main():
|
|||
apis_to_serve.add(inf.routing_table_api.value)
|
||||
|
||||
apis_to_serve.add("inspect")
|
||||
apis_to_serve.add("providers")
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
||||
|
|
@ -432,15 +409,10 @@ def main():
|
|||
)
|
||||
)
|
||||
|
||||
cprint(f"Serving API {api_str}", "white", attrs=["bold"])
|
||||
for endpoint in endpoints:
|
||||
cprint(f" {endpoint.method.upper()} {endpoint.route}", "white")
|
||||
logger.debug(f"serving APIs: {apis_to_serve}")
|
||||
|
||||
print("")
|
||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
|
||||
signal.signal(signal.SIGTERM, functools.partial(handle_signal, app))
|
||||
|
||||
app.__llama_stack_impls__ = impls
|
||||
|
||||
|
|
@ -462,15 +434,17 @@ def main():
|
|||
"ssl_keyfile": keyfile,
|
||||
"ssl_certfile": certfile,
|
||||
}
|
||||
print(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
||||
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
||||
|
||||
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
|
||||
print(f"Listening on {listen_host}:{port}")
|
||||
logger.info(f"Listening on {listen_host}:{port}")
|
||||
|
||||
uvicorn_config = {
|
||||
"app": app,
|
||||
"host": listen_host,
|
||||
"port": port,
|
||||
"lifespan": "on",
|
||||
"log_level": logger.getEffectiveLevel(),
|
||||
}
|
||||
if ssl_config:
|
||||
uvicorn_config.update(ssl_config)
|
||||
|
|
|
|||
|
|
@ -5,13 +5,12 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import importlib.resources
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import yaml
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.batch_inference import BatchInference
|
||||
|
|
@ -24,6 +23,7 @@ from llama_stack.apis.inference import Inference
|
|||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.post_training import PostTraining
|
||||
from llama_stack.apis.providers import Providers
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||
|
|
@ -33,16 +33,19 @@ from llama_stack.apis.telemetry import Telemetry
|
|||
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDBs
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
||||
from llama_stack.distribution.store.registry import create_dist_registry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class LlamaStack(
|
||||
Providers,
|
||||
VectorDBs,
|
||||
Inference,
|
||||
BatchInference,
|
||||
|
|
@ -101,12 +104,10 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
|
|||
objects_to_process = response.data if hasattr(response, "data") else response
|
||||
|
||||
for obj in objects_to_process:
|
||||
log.info(
|
||||
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
|
||||
logger.debug(
|
||||
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
|
||||
)
|
||||
|
||||
log.info("")
|
||||
|
||||
|
||||
class EnvVarError(Exception):
|
||||
def __init__(self, var_name: str, path: str = ""):
|
||||
|
|
@ -155,18 +156,34 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
|
|||
return result
|
||||
|
||||
elif isinstance(config, str):
|
||||
pattern = r"\${env\.([A-Z0-9_]+)(?::([^}]*))?}"
|
||||
# Updated pattern to support both default values (:) and conditional values (+)
|
||||
pattern = r"\${env\.([A-Z0-9_]+)(?:([:\+])([^}]*))?}"
|
||||
|
||||
def get_env_var(match):
|
||||
env_var = match.group(1)
|
||||
default_val = match.group(2)
|
||||
operator = match.group(2) # ':' for default, '+' for conditional
|
||||
value_expr = match.group(3)
|
||||
|
||||
value = os.environ.get(env_var)
|
||||
if not value:
|
||||
if default_val is None:
|
||||
raise EnvVarError(env_var, path)
|
||||
env_value = os.environ.get(env_var)
|
||||
|
||||
if operator == ":": # Default value syntax: ${env.FOO:default}
|
||||
if not env_value:
|
||||
if value_expr is None:
|
||||
raise EnvVarError(env_var, path)
|
||||
else:
|
||||
value = value_expr
|
||||
else:
|
||||
value = default_val
|
||||
value = env_value
|
||||
elif operator == "+": # Conditional value syntax: ${env.FOO+value_if_set}
|
||||
if env_value:
|
||||
value = value_expr
|
||||
else:
|
||||
# If env var is not set, return empty string for the conditional case
|
||||
value = ""
|
||||
else: # No operator case: ${env.FOO}
|
||||
if not env_value:
|
||||
raise EnvVarError(env_var, path)
|
||||
value = env_value
|
||||
|
||||
# expand "~" from the values
|
||||
return os.path.expanduser(value)
|
||||
|
|
@ -215,3 +232,53 @@ def get_stack_run_config_from_template(template: str) -> StackRunConfig:
|
|||
run_config = yaml.safe_load(path.open())
|
||||
|
||||
return StackRunConfig(**replace_env_vars(run_config))
|
||||
|
||||
|
||||
def run_config_from_adhoc_config_spec(
|
||||
adhoc_config_spec: str, provider_registry: Optional[ProviderRegistry] = None
|
||||
) -> StackRunConfig:
|
||||
"""
|
||||
Create an adhoc distribution from a list of API providers.
|
||||
|
||||
The list should be of the form "api=provider", e.g. "inference=fireworks". If you have
|
||||
multiple pairs, separate them with commas or semicolons, e.g. "inference=fireworks,safety=llama-guard,agents=meta-reference"
|
||||
"""
|
||||
|
||||
api_providers = adhoc_config_spec.replace(";", ",").split(",")
|
||||
provider_registry = provider_registry or get_provider_registry()
|
||||
|
||||
distro_dir = tempfile.mkdtemp()
|
||||
provider_configs_by_api = {}
|
||||
for api_provider in api_providers:
|
||||
api_str, provider = api_provider.split("=")
|
||||
api = Api(api_str)
|
||||
|
||||
providers_by_type = provider_registry[api]
|
||||
provider_spec = providers_by_type.get(provider)
|
||||
if not provider_spec:
|
||||
provider_spec = providers_by_type.get(f"inline::{provider}")
|
||||
if not provider_spec:
|
||||
provider_spec = providers_by_type.get(f"remote::{provider}")
|
||||
|
||||
if not provider_spec:
|
||||
raise ValueError(
|
||||
f"Provider {provider} (or remote::{provider} or inline::{provider}) not found for API {api}"
|
||||
)
|
||||
|
||||
# call method "sample_run_config" on the provider spec config class
|
||||
provider_config_type = instantiate_class_type(provider_spec.config_class)
|
||||
provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
|
||||
|
||||
provider_configs_by_api[api_str] = [
|
||||
Provider(
|
||||
provider_id=provider,
|
||||
provider_type=provider_spec.provider_type,
|
||||
config=provider_config,
|
||||
)
|
||||
]
|
||||
config = StackRunConfig(
|
||||
image_name="distro-test",
|
||||
apis=list(provider_configs_by_api.keys()),
|
||||
providers=provider_configs_by_api,
|
||||
)
|
||||
return config
|
||||
|
|
|
|||
|
|
@ -98,15 +98,23 @@ case "$env_type" in
|
|||
*)
|
||||
esac
|
||||
|
||||
set -x
|
||||
|
||||
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
|
||||
set -x
|
||||
|
||||
$PYTHON_BINARY -m llama_stack.distribution.server.server \
|
||||
--yaml-config "$yaml_config" \
|
||||
--port "$port" \
|
||||
$env_vars \
|
||||
$other_args
|
||||
elif [[ "$env_type" == "container" ]]; then
|
||||
set -x
|
||||
|
||||
# Check if container command is available
|
||||
if ! is_command_available $CONTAINER_BINARY; then
|
||||
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if is_command_available selinuxenabled &> /dev/null && selinuxenabled; then
|
||||
# Disable SELinux labels
|
||||
CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
|
||||
|
|
|
|||
|
|
@ -1,72 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
error_handler() {
|
||||
echo "Error occurred in script at line: ${1}" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
trap 'error_handler ${LINENO}' ERR
|
||||
|
||||
if [ $# -lt 3 ]; then
|
||||
echo "Usage: $0 <venv_path> <yaml_config> <port> <script_args...>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
venv_path="$1"
|
||||
shift
|
||||
|
||||
yaml_config="$1"
|
||||
shift
|
||||
|
||||
port="$1"
|
||||
shift
|
||||
|
||||
# Initialize env_vars as an empty array
|
||||
env_vars=""
|
||||
other_args=""
|
||||
# Process environment variables from --env arguments
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--env)
|
||||
|
||||
if [[ -n "$2" ]]; then
|
||||
env_vars="$env_vars --env $2"
|
||||
shift 2
|
||||
else
|
||||
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
other_args="$other_args $1"
|
||||
shift
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
echo "Using virtual environment: $venv_path"
|
||||
# Activate virtual environment
|
||||
if [ ! -d "$venv_path" ]; then
|
||||
echo -e "${RED}Error: Virtual environment not found at $venv_path${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
source "$venv_path/bin/activate"
|
||||
|
||||
set -x
|
||||
python -m llama_stack.distribution.server.server \
|
||||
--yaml-config "$yaml_config" \
|
||||
--port "$port" \
|
||||
$env_vars \
|
||||
$other_args
|
||||
|
|
@ -33,7 +33,7 @@ class DistributionRegistry(Protocol):
|
|||
|
||||
|
||||
REGISTER_PREFIX = "distributions:registry"
|
||||
KEY_VERSION = "v7"
|
||||
KEY_VERSION = "v8"
|
||||
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,199 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.inference import Model
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.distribution.store.registry import (
|
||||
CachedDiskDistributionRegistry,
|
||||
DiskDistributionRegistry,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config():
|
||||
config = SqliteKVStoreConfig(db_path="/tmp/test_registry.db")
|
||||
if os.path.exists(config.db_path):
|
||||
os.remove(config.db_path)
|
||||
return config
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def registry(config):
|
||||
registry = DiskDistributionRegistry(await kvstore_impl(config))
|
||||
await registry.initialize()
|
||||
return registry
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def cached_registry(config):
|
||||
registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
||||
await registry.initialize()
|
||||
return registry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_vector_db():
|
||||
return VectorDB(
|
||||
identifier="test_vector_db",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_db",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_model():
|
||||
return Model(
|
||||
identifier="test_model",
|
||||
provider_resource_id="test_model",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registry_initialization(registry):
|
||||
# Test empty registry
|
||||
result = await registry.get("nonexistent", "nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_registration(registry, sample_vector_db, sample_model):
|
||||
print(f"Registering {sample_vector_db}")
|
||||
await registry.register(sample_vector_db)
|
||||
print(f"Registering {sample_model}")
|
||||
await registry.register(sample_model)
|
||||
print("Getting vector_db")
|
||||
result_vector_db = await registry.get("vector_db", "test_vector_db")
|
||||
assert result_vector_db is not None
|
||||
assert result_vector_db.identifier == sample_vector_db.identifier
|
||||
assert result_vector_db.embedding_model == sample_vector_db.embedding_model
|
||||
assert result_vector_db.provider_id == sample_vector_db.provider_id
|
||||
|
||||
result_model = await registry.get("model", "test_model")
|
||||
assert result_model is not None
|
||||
assert result_model.identifier == sample_model.identifier
|
||||
assert result_model.provider_id == sample_model.provider_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cached_registry_initialization(config, sample_vector_db, sample_model):
|
||||
# First populate the disk registry
|
||||
disk_registry = DiskDistributionRegistry(await kvstore_impl(config))
|
||||
await disk_registry.initialize()
|
||||
await disk_registry.register(sample_vector_db)
|
||||
await disk_registry.register(sample_model)
|
||||
|
||||
# Test cached version loads from disk
|
||||
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
||||
await cached_registry.initialize()
|
||||
|
||||
result_vector_db = await cached_registry.get("vector_db", "test_vector_db")
|
||||
assert result_vector_db is not None
|
||||
assert result_vector_db.identifier == sample_vector_db.identifier
|
||||
assert result_vector_db.embedding_model == sample_vector_db.embedding_model
|
||||
assert result_vector_db.embedding_dimension == sample_vector_db.embedding_dimension
|
||||
assert result_vector_db.provider_id == sample_vector_db.provider_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cached_registry_updates(config):
|
||||
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
||||
await cached_registry.initialize()
|
||||
|
||||
new_vector_db = VectorDB(
|
||||
identifier="test_vector_db_2",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_db_2",
|
||||
provider_id="baz",
|
||||
)
|
||||
await cached_registry.register(new_vector_db)
|
||||
|
||||
# Verify in cache
|
||||
result_vector_db = await cached_registry.get("vector_db", "test_vector_db_2")
|
||||
assert result_vector_db is not None
|
||||
assert result_vector_db.identifier == new_vector_db.identifier
|
||||
assert result_vector_db.provider_id == new_vector_db.provider_id
|
||||
|
||||
# Verify persisted to disk
|
||||
new_registry = DiskDistributionRegistry(await kvstore_impl(config))
|
||||
await new_registry.initialize()
|
||||
result_vector_db = await new_registry.get("vector_db", "test_vector_db_2")
|
||||
assert result_vector_db is not None
|
||||
assert result_vector_db.identifier == new_vector_db.identifier
|
||||
assert result_vector_db.provider_id == new_vector_db.provider_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_duplicate_provider_registration(config):
|
||||
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
||||
await cached_registry.initialize()
|
||||
|
||||
original_vector_db = VectorDB(
|
||||
identifier="test_vector_db_2",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_db_2",
|
||||
provider_id="baz",
|
||||
)
|
||||
await cached_registry.register(original_vector_db)
|
||||
|
||||
duplicate_vector_db = VectorDB(
|
||||
identifier="test_vector_db_2",
|
||||
embedding_model="different-model",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_db_2",
|
||||
provider_id="baz", # Same provider_id
|
||||
)
|
||||
await cached_registry.register(duplicate_vector_db)
|
||||
|
||||
result = await cached_registry.get("vector_db", "test_vector_db_2")
|
||||
assert result is not None
|
||||
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_objects(config):
|
||||
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
||||
await cached_registry.initialize()
|
||||
|
||||
# Create multiple test banks
|
||||
test_vector_dbs = [
|
||||
VectorDB(
|
||||
identifier=f"test_vector_db_{i}",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id=f"test_vector_db_{i}",
|
||||
provider_id=f"provider_{i}",
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
# Register all vector_dbs
|
||||
for vector_db in test_vector_dbs:
|
||||
await cached_registry.register(vector_db)
|
||||
|
||||
# Test get_all retrieval
|
||||
all_results = await cached_registry.get_all()
|
||||
assert len(all_results) == 3
|
||||
|
||||
# Verify each vector_db was stored correctly
|
||||
for original_vector_db in test_vector_dbs:
|
||||
matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier]
|
||||
assert len(matching_vector_dbs) == 1
|
||||
stored_vector_db = matching_vector_dbs[0]
|
||||
assert stored_vector_db.embedding_model == original_vector_db.embedding_model
|
||||
assert stored_vector_db.provider_id == original_vector_db.provider_id
|
||||
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
|
||||
11
llama_stack/distribution/ui/Containerfile
Normal file
11
llama_stack/distribution/ui/Containerfile
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
# More info on playground configuration can be found here:
|
||||
# https://llama-stack.readthedocs.io/en/latest/playground
|
||||
|
||||
FROM python:3.9-slim
|
||||
WORKDIR /app
|
||||
COPY . /app/
|
||||
RUN /usr/local/bin/python -m pip install --upgrade pip && \
|
||||
/usr/local/bin/pip3 install -r requirements.txt
|
||||
EXPOSE 8501
|
||||
|
||||
ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
||||
|
|
@ -17,7 +17,7 @@ llama stack run together
|
|||
2. (Optional) Register datasets and eval tasks as resources. If you want to run pre-configured evaluation flows (e.g. Evaluations (Generation + Scoring) Page).
|
||||
|
||||
```bash
|
||||
$ llama-stack-client datasets register \
|
||||
llama-stack-client datasets register \
|
||||
--dataset-id "mmlu" \
|
||||
--provider-id "huggingface" \
|
||||
--url "https://huggingface.co/datasets/llamastack/evals" \
|
||||
|
|
@ -26,7 +26,7 @@ $ llama-stack-client datasets register \
|
|||
```
|
||||
|
||||
```bash
|
||||
$ llama-stack-client benchmarks register \
|
||||
llama-stack-client benchmarks register \
|
||||
--eval-task-id meta-reference-mmlu \
|
||||
--provider-id meta-reference \
|
||||
--dataset-id mmlu \
|
||||
|
|
@ -40,3 +40,13 @@ cd llama_stack/distribution/ui
|
|||
pip install -r requirements.txt
|
||||
streamlit run app.py
|
||||
```
|
||||
|
||||
## Environment Variables
|
||||
|
||||
| Environment Variable | Description | Default Value |
|
||||
|----------------------------|------------------------------------|---------------------------|
|
||||
| LLAMA_STACK_ENDPOINT | The endpoint for the Llama Stack | http://localhost:8321 |
|
||||
| FIREWORKS_API_KEY | API key for Fireworks provider | (empty string) |
|
||||
| TOGETHER_API_KEY | API key for Together provider | (empty string) |
|
||||
| SAMBANOVA_API_KEY | API key for SambaNova provider | (empty string) |
|
||||
| OPENAI_API_KEY | API key for OpenAI provider | (empty string) |
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def datasets():
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def benchmarks():
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def models():
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def providers():
|
||||
|
|
|
|||
|
|
@ -4,14 +4,15 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from page.distribution.benchmarks import benchmarks
|
||||
from page.distribution.datasets import datasets
|
||||
from page.distribution.models import models
|
||||
from page.distribution.scoring_functions import scoring_functions
|
||||
from page.distribution.shields import shields
|
||||
from page.distribution.vector_dbs import vector_dbs
|
||||
from streamlit_option_menu import option_menu
|
||||
|
||||
from llama_stack.distribution.ui.page.distribution.datasets import datasets
|
||||
from llama_stack.distribution.ui.page.distribution.eval_tasks import benchmarks
|
||||
from llama_stack.distribution.ui.page.distribution.models import models
|
||||
from llama_stack.distribution.ui.page.distribution.scoring_functions import scoring_functions
|
||||
from llama_stack.distribution.ui.page.distribution.shields import shields
|
||||
from llama_stack.distribution.ui.page.distribution.vector_dbs import vector_dbs
|
||||
|
||||
|
||||
def resources_page():
|
||||
options = [
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def scoring_functions():
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def shields():
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def vector_dbs():
|
||||
|
|
|
|||
|
|
@ -8,8 +8,9 @@ import json
|
|||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
from modules.utils import process_dataset
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
from llama_stack.distribution.ui.modules.utils import process_dataset
|
||||
|
||||
|
||||
def application_evaluation_page():
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@ import json
|
|||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def select_benchmark_1():
|
||||
|
|
@ -166,11 +167,10 @@ def run_evaluation_3():
|
|||
eval_candidate = st.session_state["eval_candidate"]
|
||||
|
||||
dataset_id = benchmarks[selected_benchmark].dataset_id
|
||||
rows = llama_stack_api.client.datasetio.get_rows_paginated(
|
||||
rows = llama_stack_api.client.datasets.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
)
|
||||
total_rows = len(rows.rows)
|
||||
total_rows = len(rows.data)
|
||||
# Add number of examples control
|
||||
num_rows = st.number_input(
|
||||
"Number of Examples to Evaluate",
|
||||
|
|
@ -195,7 +195,7 @@ def run_evaluation_3():
|
|||
if st.button("Run Evaluation"):
|
||||
progress_text = "Running evaluation..."
|
||||
progress_bar = st.progress(0, text=progress_text)
|
||||
rows = rows.rows
|
||||
rows = rows.data
|
||||
if num_rows < total_rows:
|
||||
rows = rows[:num_rows]
|
||||
|
||||
|
|
@ -212,7 +212,7 @@ def run_evaluation_3():
|
|||
benchmark_id=selected_benchmark,
|
||||
input_rows=[r],
|
||||
scoring_functions=benchmarks[selected_benchmark].scoring_functions,
|
||||
task_config=benchmark_config,
|
||||
benchmark_config=benchmark_config,
|
||||
)
|
||||
|
||||
for k in r.keys():
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
# Sidebar configurations
|
||||
with st.sidebar:
|
||||
|
|
|
|||
|
|
@ -7,10 +7,10 @@
|
|||
import streamlit as st
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client.types.agent_create_params import AgentConfig
|
||||
from llama_stack_client.types.memory_insert_params import Document
|
||||
from modules.api import llama_stack_api
|
||||
from modules.utils import data_url_from_file
|
||||
from llama_stack_client.types.shared.document import Document
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
from llama_stack.distribution.ui.modules.utils import data_url_from_file
|
||||
|
||||
|
||||
def rag_chat_page():
|
||||
|
|
@ -124,26 +124,22 @@ def rag_chat_page():
|
|||
else:
|
||||
strategy = {"type": "greedy"}
|
||||
|
||||
agent_config = AgentConfig(
|
||||
agent = Agent(
|
||||
llama_stack_api.client,
|
||||
model=selected_model,
|
||||
instructions=system_prompt,
|
||||
sampling_params={
|
||||
"strategy": strategy,
|
||||
},
|
||||
toolgroups=[
|
||||
tools=[
|
||||
dict(
|
||||
name="builtin::rag",
|
||||
name="builtin::rag/knowledge_search",
|
||||
args={
|
||||
"vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
|
||||
"vector_db_ids": list(selected_vector_dbs),
|
||||
},
|
||||
)
|
||||
],
|
||||
tool_choice="auto",
|
||||
tool_prompt_format="json",
|
||||
enable_session_persistence=False,
|
||||
)
|
||||
|
||||
agent = Agent(llama_stack_api.client, agent_config)
|
||||
session_id = agent.create_session("rag-session")
|
||||
|
||||
# Chat input
|
||||
|
|
|
|||
|
|
@ -13,6 +13,4 @@ DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
|
|||
|
||||
DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
|
||||
|
||||
BUILDS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "builds"
|
||||
|
||||
RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime"
|
||||
|
|
|
|||
33
llama_stack/distribution/utils/context.py
Normal file
33
llama_stack/distribution/utils/context.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from contextvars import ContextVar
|
||||
from typing import AsyncGenerator, List, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def preserve_contexts_async_generator(
|
||||
gen: AsyncGenerator[T, None], context_vars: List[ContextVar]
|
||||
) -> AsyncGenerator[T, None]:
|
||||
"""
|
||||
Wraps an async generator to preserve context variables across iterations.
|
||||
This is needed because we start a new asyncio event loop for each streaming request,
|
||||
and we need to preserve the context across the event loop boundary.
|
||||
"""
|
||||
|
||||
async def wrapper() -> AsyncGenerator[T, None]:
|
||||
while True:
|
||||
try:
|
||||
item = await gen.__anext__()
|
||||
context_values = {context_var.name: context_var.get() for context_var in context_vars}
|
||||
yield item
|
||||
for context_var in context_vars:
|
||||
_ = context_var.set(context_values[context_var.name])
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
return wrapper()
|
||||
|
|
@ -4,13 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import errno
|
||||
import logging
|
||||
import os
|
||||
import select
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
|
|
@ -20,14 +17,14 @@ import importlib
|
|||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.distribution.utils.image_types import ImageType
|
||||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
|
||||
|
||||
def formulate_run_args(image_type, image_name, config, template_name) -> list:
|
||||
env_name = ""
|
||||
if image_type == ImageType.container.value or config.container_image:
|
||||
if image_type == LlamaStackImageType.CONTAINER.value or config.container_image:
|
||||
env_name = f"distribution-{template_name}" if template_name else config.container_image
|
||||
elif image_type == ImageType.conda.value:
|
||||
elif image_type == LlamaStackImageType.CONDA.value:
|
||||
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
|
||||
env_name = image_name or current_conda_env
|
||||
if not env_name:
|
||||
|
|
@ -46,7 +43,7 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
|
|||
conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
|
||||
envs = conda_env_info["envs"]
|
||||
for envpath in envs:
|
||||
if envpath.endswith(env_name):
|
||||
if os.path.basename(envpath) == env_name:
|
||||
return envpath
|
||||
return None
|
||||
|
||||
|
|
@ -88,13 +85,6 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
|
|||
return run_args
|
||||
|
||||
|
||||
def run_with_pty(command):
|
||||
if sys.platform.startswith("win"):
|
||||
return _run_with_pty_win(command)
|
||||
else:
|
||||
return _run_with_pty_unix(command)
|
||||
|
||||
|
||||
def in_notebook():
|
||||
try:
|
||||
from IPython import get_ipython
|
||||
|
|
@ -108,19 +98,19 @@ def in_notebook():
|
|||
return True
|
||||
|
||||
|
||||
# run a command in a pseudo-terminal, with interrupt handling,
|
||||
# useful when you want to run interactive things
|
||||
def _run_with_pty_unix(command):
|
||||
import pty
|
||||
import termios
|
||||
def run_command(command: list[str]) -> int:
|
||||
"""
|
||||
Run a command with interrupt handling and output capture.
|
||||
Uses subprocess.run with direct stream piping for better performance.
|
||||
|
||||
master, slave = pty.openpty()
|
||||
Args:
|
||||
command (list): The command to run.
|
||||
|
||||
old_settings = termios.tcgetattr(sys.stdin)
|
||||
Returns:
|
||||
int: The return code of the command.
|
||||
"""
|
||||
original_sigint = signal.getsignal(signal.SIGINT)
|
||||
|
||||
ctrl_c_pressed = False
|
||||
process = None
|
||||
|
||||
def sigint_handler(signum, frame):
|
||||
nonlocal ctrl_c_pressed
|
||||
|
|
@ -131,106 +121,19 @@ def _run_with_pty_unix(command):
|
|||
# Set up the signal handler
|
||||
signal.signal(signal.SIGINT, sigint_handler)
|
||||
|
||||
new_settings = termios.tcgetattr(sys.stdin)
|
||||
new_settings[3] = new_settings[3] & ~termios.ECHO # Disable echo
|
||||
new_settings[3] = new_settings[3] & ~termios.ICANON # Disable canonical mode
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, new_settings)
|
||||
|
||||
process = subprocess.Popen(
|
||||
# Run the command with stdout/stderr piped directly to system streams
|
||||
result = subprocess.run(
|
||||
command,
|
||||
stdin=slave,
|
||||
stdout=slave,
|
||||
stderr=slave,
|
||||
universal_newlines=True,
|
||||
preexec_fn=os.setsid,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
|
||||
# Close the slave file descriptor as it's now owned by the subprocess
|
||||
os.close(slave)
|
||||
|
||||
def handle_io():
|
||||
while not ctrl_c_pressed:
|
||||
try:
|
||||
rlist, _, _ = select.select([sys.stdin, master], [], [], 0.1)
|
||||
|
||||
if sys.stdin in rlist:
|
||||
data = os.read(sys.stdin.fileno(), 1024)
|
||||
if not data:
|
||||
break
|
||||
os.write(master, data)
|
||||
|
||||
if master in rlist:
|
||||
data = os.read(master, 1024)
|
||||
if not data:
|
||||
break
|
||||
sys.stdout.buffer.write(data)
|
||||
sys.stdout.flush()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
# This will be raised when Ctrl+C is pressed
|
||||
break
|
||||
|
||||
if process.poll() is not None:
|
||||
break
|
||||
|
||||
handle_io()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
pass
|
||||
except OSError as e:
|
||||
if e.errno != errno.EIO:
|
||||
raise
|
||||
finally:
|
||||
# Clean up
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
|
||||
signal.signal(signal.SIGINT, original_sigint)
|
||||
|
||||
os.close(master)
|
||||
if process and process.poll() is None:
|
||||
process.terminate()
|
||||
process.wait()
|
||||
|
||||
return process.returncode
|
||||
|
||||
|
||||
# run a command in a pseudo-terminal in windows, with interrupt handling,
|
||||
def _run_with_pty_win(command):
|
||||
"""
|
||||
Runs a command with interactive support using subprocess directly.
|
||||
"""
|
||||
try:
|
||||
# For shell scripts on Windows, use appropriate shell
|
||||
if isinstance(command, (list, tuple)):
|
||||
if command[0].endswith(".sh"):
|
||||
if os.path.exists("/usr/bin/bash"): # WSL
|
||||
command = ["bash"] + command
|
||||
else:
|
||||
# Use cmd.exe with bash while preserving all arguments
|
||||
command = ["cmd.exe", "/c", "bash"] + command
|
||||
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
shell=True,
|
||||
universal_newlines=True,
|
||||
)
|
||||
|
||||
process.wait()
|
||||
|
||||
return result.returncode
|
||||
except subprocess.SubprocessError as e:
|
||||
log.error(f"Subprocess error: {e}")
|
||||
return 1
|
||||
except Exception as e:
|
||||
print(f"Error: {str(e)}")
|
||||
log.exception(f"Unexpected error: {e}")
|
||||
return 1
|
||||
finally:
|
||||
if process and process.poll() is None:
|
||||
process.terminate()
|
||||
process.wait()
|
||||
return process.returncode
|
||||
|
||||
|
||||
def run_command(command):
|
||||
try:
|
||||
result = subprocess.run(command, capture_output=True, text=True, check=True)
|
||||
print("Script Output\n", result.stdout)
|
||||
return result.returncode
|
||||
except subprocess.CalledProcessError as e:
|
||||
print("Error running script:", e)
|
||||
print("Error output:", e.stderr)
|
||||
return e.returncode
|
||||
# Restore the original signal handler
|
||||
signal.signal(signal.SIGINT, original_sigint)
|
||||
|
|
|
|||
|
|
@ -4,10 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
import enum
|
||||
|
||||
|
||||
class ImageType(Enum):
|
||||
container = "container"
|
||||
conda = "conda"
|
||||
venv = "venv"
|
||||
class LlamaStackImageType(enum.Enum):
|
||||
CONTAINER = "container"
|
||||
CONDA = "conda"
|
||||
VENV = "venv"
|
||||
|
|
|
|||
155
llama_stack/distribution/utils/tests/test_context.py
Normal file
155
llama_stack/distribution/utils/tests/test_context.py
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextvars import ContextVar
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserve_contexts_with_exception():
|
||||
# Create context variable
|
||||
context_var = ContextVar("exception_var", default="initial")
|
||||
token = context_var.set("start_value")
|
||||
|
||||
# Create an async generator that raises an exception
|
||||
async def exception_generator():
|
||||
yield context_var.get()
|
||||
context_var.set("modified")
|
||||
raise ValueError("Test exception")
|
||||
yield None # This will never be reached
|
||||
|
||||
# Wrap the generator
|
||||
wrapped_gen = preserve_contexts_async_generator(exception_generator(), [context_var])
|
||||
|
||||
# First iteration should work
|
||||
value = await wrapped_gen.__anext__()
|
||||
assert value == "start_value"
|
||||
|
||||
# Second iteration should raise the exception
|
||||
with pytest.raises(ValueError, match="Test exception"):
|
||||
await wrapped_gen.__anext__()
|
||||
|
||||
# Clean up
|
||||
context_var.reset(token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserve_contexts_empty_generator():
|
||||
# Create context variable
|
||||
context_var = ContextVar("empty_var", default="initial")
|
||||
token = context_var.set("value")
|
||||
|
||||
# Create an empty async generator
|
||||
async def empty_generator():
|
||||
if False: # This condition ensures the generator yields nothing
|
||||
yield None
|
||||
|
||||
# Wrap the generator
|
||||
wrapped_gen = preserve_contexts_async_generator(empty_generator(), [context_var])
|
||||
|
||||
# The generator should raise StopAsyncIteration immediately
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await wrapped_gen.__anext__()
|
||||
|
||||
# Context variable should remain unchanged
|
||||
assert context_var.get() == "value"
|
||||
|
||||
# Clean up
|
||||
context_var.reset(token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserve_contexts_across_event_loops():
|
||||
"""
|
||||
Test that context variables are preserved across event loop boundaries with nested generators.
|
||||
This simulates the real-world scenario where:
|
||||
1. A new event loop is created for each streaming request
|
||||
2. The async generator runs inside that loop
|
||||
3. There are multiple levels of nested generators
|
||||
4. Context needs to be preserved across these boundaries
|
||||
"""
|
||||
# Create context variables
|
||||
request_id = ContextVar("request_id", default=None)
|
||||
user_id = ContextVar("user_id", default=None)
|
||||
|
||||
# Set initial values
|
||||
|
||||
# Results container to verify values across thread boundaries
|
||||
results = []
|
||||
|
||||
# Inner-most generator (level 2)
|
||||
async def inner_generator():
|
||||
# Should have the context from the outer scope
|
||||
yield (1, request_id.get(), user_id.get())
|
||||
|
||||
# Modify one context variable
|
||||
user_id.set("user-modified")
|
||||
|
||||
# Should reflect the modification
|
||||
yield (2, request_id.get(), user_id.get())
|
||||
|
||||
# Middle generator (level 1)
|
||||
async def middle_generator():
|
||||
inner_gen = inner_generator()
|
||||
|
||||
# Forward the first yield from inner
|
||||
item = await inner_gen.__anext__()
|
||||
yield item
|
||||
|
||||
# Forward the second yield from inner
|
||||
item = await inner_gen.__anext__()
|
||||
yield item
|
||||
|
||||
request_id.set("req-modified")
|
||||
|
||||
# Add our own yield with both modified variables
|
||||
yield (3, request_id.get(), user_id.get())
|
||||
|
||||
# Function to run in a separate thread with a new event loop
|
||||
def run_in_new_loop():
|
||||
# Create a new event loop for this thread
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# Outer generator (runs in the new loop)
|
||||
async def outer_generator():
|
||||
request_id.set("req-12345")
|
||||
user_id.set("user-6789")
|
||||
# Wrap the middle generator
|
||||
wrapped_gen = preserve_contexts_async_generator(middle_generator(), [request_id, user_id])
|
||||
|
||||
# Process all items from the middle generator
|
||||
async for item in wrapped_gen:
|
||||
# Store results for verification
|
||||
results.append(item)
|
||||
|
||||
# Run the outer generator in the new loop
|
||||
loop.run_until_complete(outer_generator())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Run the generator chain in a separate thread with a new event loop
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(run_in_new_loop)
|
||||
future.result() # Wait for completion
|
||||
|
||||
# Verify the results
|
||||
assert len(results) == 3
|
||||
|
||||
# First yield should have original values
|
||||
assert results[0] == (1, "req-12345", "user-6789")
|
||||
|
||||
# Second yield should have modified user_id
|
||||
assert results[1] == (2, "req-12345", "user-modified")
|
||||
|
||||
# Third yield should have both modified values
|
||||
assert results[2] == (3, "req-modified", "user-modified")
|
||||
243
llama_stack/log.py
Normal file
243
llama_stack/log.py
Normal file
|
|
@ -0,0 +1,243 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from logging.config import dictConfig
|
||||
from typing import Dict, Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.errors import MarkupError
|
||||
from rich.logging import RichHandler
|
||||
from termcolor import cprint
|
||||
|
||||
from .distribution.datatypes import LoggingConfig
|
||||
|
||||
# Default log level
|
||||
DEFAULT_LOG_LEVEL = logging.INFO
|
||||
|
||||
# Predefined categories
|
||||
CATEGORIES = [
|
||||
"core",
|
||||
"server",
|
||||
"router",
|
||||
"inference",
|
||||
"agents",
|
||||
"safety",
|
||||
"eval",
|
||||
"tools",
|
||||
"client",
|
||||
]
|
||||
|
||||
# Initialize category levels with default level
|
||||
_category_levels: Dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES}
|
||||
|
||||
|
||||
def config_to_category_levels(category: str, level: str):
|
||||
"""
|
||||
Helper function to be called either by environment parsing or yaml parsing to go from a list of categories and levels to a dictionary ready to be
|
||||
used by the logger dictConfig.
|
||||
|
||||
Parameters:
|
||||
category (str): logging category to apply the level to
|
||||
level (str): logging level to be used in the category
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: A dictionary mapping categories to their log levels.
|
||||
"""
|
||||
|
||||
category_levels: Dict[str, int] = {}
|
||||
level_value = logging._nameToLevel.get(str(level).upper())
|
||||
if level_value is None:
|
||||
logging.warning(f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'.")
|
||||
return category_levels
|
||||
|
||||
if category == "all":
|
||||
# Apply the log level to all categories and the root logger
|
||||
for cat in CATEGORIES:
|
||||
category_levels[cat] = level_value
|
||||
# Set the root logger's level to the specified level
|
||||
category_levels["root"] = level_value
|
||||
elif category in CATEGORIES:
|
||||
category_levels[category] = level_value
|
||||
logging.info(f"Setting '{category}' category to level '{level}'.")
|
||||
else:
|
||||
logging.warning(f"Unknown logging category: {category}. No changes made.")
|
||||
return category_levels
|
||||
|
||||
|
||||
def parse_yaml_config(yaml_config: LoggingConfig) -> Dict[str, int]:
|
||||
"""
|
||||
Helper function to parse a yaml logging configuration found in the run.yaml
|
||||
|
||||
Parameters:
|
||||
yaml_config (Logging): the logger config object found in the run.yaml
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: A dictionary mapping categories to their log levels.
|
||||
"""
|
||||
category_levels = {}
|
||||
for category, level in yaml_config.category_levels.items():
|
||||
category_levels.update(config_to_category_levels(category=category, level=level))
|
||||
|
||||
return category_levels
|
||||
|
||||
|
||||
def parse_environment_config(env_config: str) -> Dict[str, int]:
|
||||
"""
|
||||
Parse the LLAMA_STACK_LOGGING environment variable and return a dictionary of category log levels.
|
||||
|
||||
Parameters:
|
||||
env_config (str): The value of the LLAMA_STACK_LOGGING environment variable.
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: A dictionary mapping categories to their log levels.
|
||||
"""
|
||||
category_levels = {}
|
||||
for pair in env_config.split(";"):
|
||||
if not pair.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
category, level = pair.split("=", 1)
|
||||
category = category.strip().lower()
|
||||
level = level.strip().upper() # Convert to uppercase for logging._nameToLevel
|
||||
category_levels.update(config_to_category_levels(category=category, level=level))
|
||||
|
||||
except ValueError:
|
||||
logging.warning(f"Invalid logging configuration: '{pair}'. Expected format: 'category=level'.")
|
||||
|
||||
return category_levels
|
||||
|
||||
|
||||
class CustomRichHandler(RichHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs["console"] = Console(width=120)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def emit(self, record):
|
||||
"""Override emit to handle markup errors gracefully."""
|
||||
try:
|
||||
super().emit(record)
|
||||
except MarkupError:
|
||||
original_markup = self.markup
|
||||
self.markup = False
|
||||
try:
|
||||
super().emit(record)
|
||||
finally:
|
||||
self.markup = original_markup
|
||||
|
||||
|
||||
def setup_logging(category_levels: Dict[str, int], log_file: str | None) -> None:
|
||||
"""
|
||||
Configure logging based on the provided category log levels and an optional log file.
|
||||
|
||||
Parameters:
|
||||
category_levels (Dict[str, int]): A dictionary mapping categories to their log levels.
|
||||
log_file (str): Path to a log file to additionally pipe the logs into
|
||||
"""
|
||||
log_format = "[dim]%(asctime)s %(name)s:%(lineno)d[/] [yellow dim]%(category)s[/]: %(message)s"
|
||||
|
||||
class CategoryFilter(logging.Filter):
|
||||
"""Ensure category is always present in log records."""
|
||||
|
||||
def filter(self, record):
|
||||
if not hasattr(record, "category"):
|
||||
record.category = "uncategorized" # Default to 'uncategorized' if no category found
|
||||
return True
|
||||
|
||||
# Determine the root logger's level (default to WARNING if not specified)
|
||||
root_level = category_levels.get("root", logging.WARNING)
|
||||
|
||||
handlers = {
|
||||
"console": {
|
||||
"()": CustomRichHandler, # Use custom console handler
|
||||
"formatter": "rich",
|
||||
"rich_tracebacks": True,
|
||||
"show_time": False,
|
||||
"show_path": False,
|
||||
"markup": True,
|
||||
"filters": ["category_filter"],
|
||||
}
|
||||
}
|
||||
|
||||
# Add a file handler if log_file is set
|
||||
if log_file:
|
||||
handlers["file"] = {
|
||||
"class": "logging.FileHandler",
|
||||
"formatter": "rich",
|
||||
"filename": log_file,
|
||||
"mode": "a",
|
||||
"encoding": "utf-8",
|
||||
}
|
||||
|
||||
logging_config = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"rich": {
|
||||
"()": logging.Formatter,
|
||||
"format": log_format,
|
||||
}
|
||||
},
|
||||
"handlers": handlers,
|
||||
"filters": {
|
||||
"category_filter": {
|
||||
"()": CategoryFilter,
|
||||
}
|
||||
},
|
||||
"loggers": {
|
||||
category: {
|
||||
"handlers": list(handlers.keys()), # Apply all handlers
|
||||
"level": category_levels.get(category, DEFAULT_LOG_LEVEL),
|
||||
"propagate": False, # Disable propagation to root logger
|
||||
}
|
||||
for category in CATEGORIES
|
||||
},
|
||||
"root": {
|
||||
"handlers": list(handlers.keys()),
|
||||
"level": root_level, # Set root logger's level dynamically
|
||||
},
|
||||
}
|
||||
dictConfig(logging_config)
|
||||
|
||||
# Ensure third-party libraries follow the root log level
|
||||
for _, logger in logging.root.manager.loggerDict.items():
|
||||
if isinstance(logger, logging.Logger):
|
||||
logger.setLevel(root_level)
|
||||
|
||||
|
||||
def get_logger(
|
||||
name: str, category: str = "uncategorized", config: Optional[LoggingConfig] | None = None
|
||||
) -> logging.LoggerAdapter:
|
||||
"""
|
||||
Returns a logger with the specified name and category.
|
||||
If no category is provided, defaults to 'uncategorized'.
|
||||
|
||||
Parameters:
|
||||
name (str): The name of the logger (e.g., module or filename).
|
||||
category (str): The category of the logger (default 'uncategorized').
|
||||
config (Logging): optional yaml config to override the existing logger configuration
|
||||
|
||||
Returns:
|
||||
logging.LoggerAdapter: Configured logger with category support.
|
||||
"""
|
||||
if config:
|
||||
_category_levels.update(parse_yaml_config(config))
|
||||
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(_category_levels.get(category, DEFAULT_LOG_LEVEL))
|
||||
return logging.LoggerAdapter(logger, {"category": category})
|
||||
|
||||
|
||||
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
|
||||
if env_config:
|
||||
cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", "yellow")
|
||||
_category_levels.update(parse_environment_config(env_config))
|
||||
|
||||
log_file = os.environ.get("LLAMA_STACK_LOG_FILE")
|
||||
|
||||
setup_logging(_category_levels, log_file)
|
||||
|
|
@ -11,16 +11,135 @@
|
|||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import base64
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
# import all for backwards compatibility
|
||||
from llama_models.datatypes import * # noqa: F403
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||
|
||||
# The goal is that these set of types are relevant for all Llama models.
|
||||
# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
|
||||
# the llama3 series of models.
|
||||
|
||||
|
||||
class Role(Enum):
|
||||
system = "system"
|
||||
user = "user"
|
||||
assistant = "assistant"
|
||||
tool = "tool"
|
||||
|
||||
|
||||
class BuiltinTool(Enum):
|
||||
brave_search = "brave_search"
|
||||
wolfram_alpha = "wolfram_alpha"
|
||||
photogen = "photogen"
|
||||
code_interpreter = "code_interpreter"
|
||||
|
||||
|
||||
Primitive = Union[str, int, float, bool, None]
|
||||
RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
call_id: str
|
||||
tool_name: Union[BuiltinTool, str]
|
||||
# Plan is to deprecate the Dict in favor of a JSON string
|
||||
# that is parsed on the client side instead of trying to manage
|
||||
# the recursive type here.
|
||||
# Making this a union so that client side can start prepping for this change.
|
||||
# Eventually, we will remove both the Dict and arguments_json field,
|
||||
# and arguments will just be a str
|
||||
arguments: Union[str, Dict[str, RecursiveType]]
|
||||
arguments_json: Optional[str] = None
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
def validate_field(cls, v):
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return BuiltinTool(v)
|
||||
except ValueError:
|
||||
return v
|
||||
return v
|
||||
|
||||
|
||||
class ToolPromptFormat(Enum):
|
||||
"""Prompt format for calling custom / zero shot tools.
|
||||
|
||||
:cvar json: JSON format for calling tools. It takes the form:
|
||||
{
|
||||
"type": "function",
|
||||
"function" : {
|
||||
"name": "function_name",
|
||||
"description": "function_description",
|
||||
"parameters": {...}
|
||||
}
|
||||
}
|
||||
:cvar function_tag: Function tag format, pseudo-XML. This looks like:
|
||||
<function=function_name>(parameters)</function>
|
||||
|
||||
:cvar python_list: Python list. The output is a valid Python expression that can be
|
||||
evaluated to a list. Each element in the list is a function call. Example:
|
||||
["function_name(param1, param2)", "function_name(param1, param2)"]
|
||||
"""
|
||||
|
||||
json = "json"
|
||||
function_tag = "function_tag"
|
||||
python_list = "python_list"
|
||||
|
||||
|
||||
class StopReason(Enum):
|
||||
end_of_turn = "end_of_turn"
|
||||
end_of_message = "end_of_message"
|
||||
out_of_tokens = "out_of_tokens"
|
||||
|
||||
|
||||
class RawMediaItem(BaseModel):
|
||||
type: Literal["image"] = "image"
|
||||
data: bytes | BytesIO
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@field_serializer("data")
|
||||
def serialize_data(self, data: Optional[bytes], _info):
|
||||
if data is None:
|
||||
return None
|
||||
return base64.b64encode(data).decode("utf-8")
|
||||
|
||||
@field_validator("data", mode="before")
|
||||
@classmethod
|
||||
def validate_data(cls, v):
|
||||
if isinstance(v, str):
|
||||
return base64.b64decode(v)
|
||||
return v
|
||||
|
||||
|
||||
class RawTextItem(BaseModel):
|
||||
type: Literal["text"] = "text"
|
||||
text: str
|
||||
|
||||
|
||||
RawContentItem = Annotated[Union[RawTextItem, RawMediaItem], Field(discriminator="type")]
|
||||
|
||||
RawContent = str | RawContentItem | List[RawContentItem]
|
||||
|
||||
|
||||
class RawMessage(BaseModel):
|
||||
role: Literal["user"] | Literal["system"] | Literal["tool"] | Literal["assistant"]
|
||||
content: RawContent
|
||||
|
||||
# This is for RAG but likely should be absorbed into content
|
||||
context: Optional[RawContent] = None
|
||||
|
||||
# These are for the output message coming from the assistant
|
||||
stop_reason: Optional[StopReason] = None
|
||||
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_schema(ToolCall)
|
||||
|
||||
|
||||
|
|
@ -67,13 +186,11 @@ class TopKSamplingStrategy(BaseModel):
|
|||
top_k: int = Field(..., ge=1)
|
||||
|
||||
|
||||
SamplingStrategy = register_schema(
|
||||
Annotated[
|
||||
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="SamplingStrategy",
|
||||
)
|
||||
SamplingStrategy = Annotated[
|
||||
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(SamplingStrategy, name="SamplingStrategy")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
285
llama_stack/models/llama/llama3/chat_format.py
Normal file
285
llama_stack/models/llama/llama3/chat_format.py
Normal file
|
|
@ -0,0 +1,285 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import io
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from PIL import Image as PIL_Image
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
RawContent,
|
||||
RawMediaItem,
|
||||
RawMessage,
|
||||
RawTextItem,
|
||||
Role,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
||||
from .tokenizer import Tokenizer
|
||||
from .tool_utils import ToolUtils
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisionInput:
|
||||
mask: List[List[int]]
|
||||
images: List[PIL_Image.Image]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMInput:
|
||||
tokens: List[int]
|
||||
vision: Optional[VisionInput] = None
|
||||
|
||||
|
||||
def role_str(role: Role) -> str:
|
||||
role_strs = {
|
||||
Role.user: "user",
|
||||
Role.system: "system",
|
||||
Role.tool: "ipython", # special
|
||||
Role.assistant: "assistant",
|
||||
}
|
||||
return role_strs[role]
|
||||
|
||||
|
||||
class ChatFormat:
|
||||
possible_headers: Dict[Role, str]
|
||||
|
||||
def __init__(self, tokenizer: Tokenizer):
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.possible_headers = {role: f"<|start_header_id|>{role_str(role)}<|end_header_id|>\n\n" for role in Role}
|
||||
self.vision_token = self.tokenizer.special_tokens["<|image|>"]
|
||||
|
||||
def _encode_header(self, role: str) -> List[int]:
|
||||
tokens = []
|
||||
tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
|
||||
tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))
|
||||
tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
|
||||
tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
|
||||
return tokens
|
||||
|
||||
def encode_content(self, content: RawContent) -> LLMInput:
|
||||
tokens, images = self._encode_content(content, bos=True)
|
||||
return self._model_input_from_tokens_images(tokens, images)
|
||||
|
||||
def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[PIL_Image.Image]]:
|
||||
tokens = []
|
||||
images = []
|
||||
|
||||
added_bos = False
|
||||
|
||||
def _process(c):
|
||||
nonlocal added_bos, bos
|
||||
|
||||
if isinstance(c, str) or isinstance(c, RawTextItem):
|
||||
if isinstance(c, RawTextItem):
|
||||
c = c.text
|
||||
tokens.extend(self.tokenizer.encode(c, bos=False if added_bos else bos, eos=False))
|
||||
added_bos = True
|
||||
|
||||
elif isinstance(c, RawMediaItem):
|
||||
bos = False if added_bos else bos
|
||||
if bos:
|
||||
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
|
||||
added_bos = True
|
||||
tokens.append(self.vision_token)
|
||||
|
||||
bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
|
||||
image = PIL_Image.open(bytes_io)
|
||||
image = image.convert("RGB")
|
||||
images.append(image)
|
||||
|
||||
if isinstance(content, list):
|
||||
for c in content:
|
||||
_process(c)
|
||||
else:
|
||||
_process(content)
|
||||
|
||||
return tokens, images
|
||||
|
||||
def encode_message(
|
||||
self, message: RawMessage, tool_prompt_format: ToolPromptFormat
|
||||
) -> Tuple[List[int], List[PIL_Image.Image]]:
|
||||
tokens = self._encode_header(message.role)
|
||||
images = []
|
||||
|
||||
def _process_content(c):
|
||||
toks, imgs = self._encode_content(c)
|
||||
tokens.extend(toks)
|
||||
images.extend(imgs)
|
||||
|
||||
if (
|
||||
message.role == "assistant"
|
||||
and len(message.tool_calls) > 0
|
||||
and message.tool_calls[0].tool_name == BuiltinTool.code_interpreter
|
||||
):
|
||||
tokens.append(self.tokenizer.special_tokens["<|python_tag|>"])
|
||||
|
||||
_process_content(message.content)
|
||||
|
||||
if message.role == "user" and message.context is not None:
|
||||
# This is RAG context; why is it here in the chat format? I don't think
|
||||
# this is needed and can be moved upwards
|
||||
_process_content("\n\n")
|
||||
_process_content(message.context)
|
||||
|
||||
if message.role == "assistant":
|
||||
for t in message.tool_calls:
|
||||
content = ToolUtils.encode_tool_call(t, tool_prompt_format)
|
||||
_process_content(content)
|
||||
|
||||
eom = False
|
||||
if message.role == "assistant":
|
||||
eom = message.stop_reason == StopReason.end_of_message
|
||||
|
||||
tokens.append(self.tokenizer.special_tokens["<|eom_id|>" if eom else "<|eot_id|>"])
|
||||
return tokens, images
|
||||
|
||||
def encode_dialog_prompt(
|
||||
self,
|
||||
messages: List[RawMessage],
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
) -> LLMInput:
|
||||
tool_prompt_format = tool_prompt_format or ToolPromptFormat.json
|
||||
tokens = []
|
||||
images = []
|
||||
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
|
||||
for message in messages:
|
||||
toks, imgs = self.encode_message(message, tool_prompt_format)
|
||||
tokens.extend(toks)
|
||||
images.extend(imgs)
|
||||
|
||||
# Add the start of an assistant message for the model to complete.
|
||||
tokens.extend(self._encode_header("assistant"))
|
||||
|
||||
return self._model_input_from_tokens_images(tokens, images)
|
||||
|
||||
# TODO(this should be generic, not only for assistant messages)
|
||||
def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
|
||||
content = self.tokenizer.decode(tokens)
|
||||
|
||||
return self.decode_assistant_message_from_content(content, stop_reason)
|
||||
|
||||
def decode_assistant_message_from_content(self, content: str, stop_reason: StopReason) -> RawMessage:
|
||||
content = content.strip(" ")
|
||||
header_str = self.possible_headers[Role.assistant]
|
||||
if content.startswith(header_str):
|
||||
content = content[len(header_str) :]
|
||||
|
||||
ipython = content.startswith("<|python_tag|>")
|
||||
if ipython:
|
||||
content = content[len("<|python_tag|>") :]
|
||||
|
||||
if content.endswith("<|eot_id|>"):
|
||||
content = content[: -len("<|eot_id|>")]
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif content.endswith("<|eom_id|>"):
|
||||
content = content[: -len("<|eom_id|>")]
|
||||
stop_reason = StopReason.end_of_message
|
||||
|
||||
tool_name = None
|
||||
tool_arguments = {}
|
||||
|
||||
custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
|
||||
if custom_tool_info is not None:
|
||||
tool_name, tool_arguments = custom_tool_info
|
||||
# Sometimes when agent has custom tools alongside builin tools
|
||||
# Agent responds for builtin tool calls in the format of the custom tools
|
||||
# This code tries to handle that case
|
||||
if tool_name in BuiltinTool.__members__:
|
||||
tool_name = BuiltinTool[tool_name]
|
||||
if isinstance(tool_arguments, dict):
|
||||
tool_arguments = {
|
||||
"query": list(tool_arguments.values())[0],
|
||||
}
|
||||
else:
|
||||
builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
|
||||
if builtin_tool_info is not None:
|
||||
tool_name, query = builtin_tool_info
|
||||
tool_arguments = {
|
||||
"query": query,
|
||||
}
|
||||
if tool_name in BuiltinTool.__members__:
|
||||
tool_name = BuiltinTool[tool_name]
|
||||
elif ipython:
|
||||
tool_name = BuiltinTool.code_interpreter
|
||||
tool_arguments = {
|
||||
"code": content,
|
||||
}
|
||||
|
||||
tool_calls = []
|
||||
if tool_name is not None and tool_arguments is not None:
|
||||
call_id = str(uuid.uuid4())
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
call_id=call_id,
|
||||
tool_name=tool_name,
|
||||
arguments=tool_arguments,
|
||||
arguments_json=json.dumps(tool_arguments),
|
||||
)
|
||||
)
|
||||
content = ""
|
||||
|
||||
return RawMessage(
|
||||
role="assistant",
|
||||
content=content,
|
||||
stop_reason=stop_reason,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
def _model_input_from_tokens_images(self, tokens: List[int], images: List[PIL_Image.Image]) -> LLMInput:
|
||||
vision_input = None
|
||||
if len(images) > 0:
|
||||
vision_input = VisionInput(
|
||||
mask=create_vision_mask(tokens, self.vision_token),
|
||||
images=images,
|
||||
)
|
||||
|
||||
return LLMInput(
|
||||
tokens=[128256 if token == self.vision_token else token for token in tokens],
|
||||
vision=vision_input,
|
||||
)
|
||||
|
||||
|
||||
def create_vision_mask(
|
||||
tokens: List[int],
|
||||
vision_token: int,
|
||||
) -> List[List[int]]:
|
||||
vision_token_locations = [i for i, token in enumerate(tokens) if token == vision_token]
|
||||
if len(vision_token_locations) == 0:
|
||||
return []
|
||||
|
||||
if len(vision_token_locations) == 1:
|
||||
# only one image present, unmask until end of sequence
|
||||
return [[vision_token_locations[0], -1]]
|
||||
vision_masks = [
|
||||
[loc1, loc2] for loc1, loc2 in zip(vision_token_locations[:-1], vision_token_locations[1:], strict=False)
|
||||
]
|
||||
# last image will attend to all subsequent text
|
||||
vision_masks.append([vision_token_locations[-1], len(tokens)])
|
||||
|
||||
# if there are two or more consecutive vision tokens,
|
||||
# they should all attend to all subsequent
|
||||
# text present
|
||||
last_mask_end = vision_masks[-1][1]
|
||||
for vision_mask in vision_masks[::-1]:
|
||||
if vision_mask[0] == vision_mask[1] - 1:
|
||||
vision_mask[1] = last_mask_end
|
||||
last_mask_end = vision_mask[1]
|
||||
return vision_masks
|
||||
|
|
@ -14,20 +14,19 @@
|
|||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from llama_models.datatypes import (
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
RawMessage,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition
|
||||
|
||||
from . import template_data
|
||||
from .chat_format import ChatFormat
|
||||
from .prompt_templates import (
|
||||
BuiltinToolGenerator,
|
||||
FunctionTagCustomToolGenerator,
|
||||
|
|
@ -35,6 +34,7 @@ from .prompt_templates import (
|
|||
SystemDefaultGenerator,
|
||||
ToolResponseGenerator,
|
||||
)
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
THIS_DIR = Path(__file__).parent
|
||||
|
||||
|
|
|
|||
|
|
@ -15,11 +15,8 @@ import textwrap
|
|||
from datetime import datetime
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from llama_models.datatypes import (
|
||||
BuiltinTool,
|
||||
)
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
)
|
||||
|
|
@ -37,7 +34,9 @@ class SystemDefaultGenerator(PromptTemplateGeneratorBase):
|
|||
)
|
||||
return PromptTemplate(
|
||||
template_str.lstrip("\n"),
|
||||
{"today": datetime.now().strftime("%d %B %Y")},
|
||||
{
|
||||
"today": datetime.now().strftime("%d %B %Y") # noqa: DTZ005 - we don't care about timezones here since we are displaying the date
|
||||
},
|
||||
)
|
||||
|
||||
def data_examples(self) -> List[Any]:
|
||||
|
|
@ -226,10 +225,9 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
|
|||
class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
||||
DEFAULT_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
You are a helpful assistant. You have access to functions, but you should only use them if they are required.
|
||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
||||
also point it out. You should only return the function call in tools call sections.
|
||||
Based on the question, you may or may not need to make one function/tool call to achieve the purpose.
|
||||
|
||||
{{ function_description }}
|
||||
""".strip("\n")
|
||||
|
|
@ -246,6 +244,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
template_str = textwrap.dedent(
|
||||
"""
|
||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
For a boolean parameter, be sure to use `True` or `False` (capitalized) for the value.
|
||||
You SHOULD NOT include any other text in the response.
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
|
|
|||
|
|
@ -11,11 +11,8 @@
|
|||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from llama_models.datatypes import (
|
||||
BuiltinTool,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
)
|
||||
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
|
||||
|
||||
from .prompt_templates import (
|
||||
BuiltinToolGenerator,
|
||||
|
|
|
|||
|
|
@ -1,199 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import textwrap
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
|
||||
from .prompt_templates import (
|
||||
BuiltinToolGenerator,
|
||||
FunctionTagCustomToolGenerator,
|
||||
JsonCustomToolGenerator,
|
||||
PythonListCustomToolGenerator,
|
||||
SystemDefaultGenerator,
|
||||
)
|
||||
|
||||
|
||||
class PromptTemplateTests(unittest.TestCase):
|
||||
def check_generator_output(self, generator, expected_text):
|
||||
example = generator.data_examples()[0]
|
||||
|
||||
pt = generator.gen(example)
|
||||
text = pt.render()
|
||||
# print(text) # debugging
|
||||
assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
|
||||
|
||||
def test_system_default(self):
|
||||
generator = SystemDefaultGenerator()
|
||||
today = datetime.now().strftime("%d %B %Y")
|
||||
expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
|
||||
self.check_generator_output(generator, expected_text)
|
||||
|
||||
def test_system_builtin_only(self):
|
||||
generator = BuiltinToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
Environment: ipython
|
||||
Tools: brave_search, wolfram_alpha
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||
|
||||
def test_system_custom_only(self):
|
||||
self.maxDiff = None
|
||||
generator = JsonCustomToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
Answer the user's question by making use of the following functions if needed.
|
||||
If none of the function can be used, please say so.
|
||||
Here is a list of functions in JSON format:
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "trending_songs",
|
||||
"description": "Returns the trending songs on a Music site",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": [
|
||||
{
|
||||
"n": {
|
||||
"type": "object",
|
||||
"description": "The number of songs to return"
|
||||
}
|
||||
},
|
||||
{
|
||||
"genre": {
|
||||
"type": "object",
|
||||
"description": "The genre of the songs to return"
|
||||
}
|
||||
}
|
||||
],
|
||||
"required": ["n"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Return function calls in JSON format.
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||
|
||||
def test_system_custom_function_tag(self):
|
||||
self.maxDiff = None
|
||||
generator = FunctionTagCustomToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
You have access to the following functions:
|
||||
|
||||
Use the function 'trending_songs' to 'Returns the trending songs on a Music site':
|
||||
{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}
|
||||
|
||||
Think very carefully before calling functions.
|
||||
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
||||
|
||||
<function=example_function_name>{"example_name": "example_value"}</function>
|
||||
|
||||
Reminder:
|
||||
- If looking for real time information use relevant functions before falling back to brave_search
|
||||
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
||||
- Required parameters MUST be specified
|
||||
- Only call one function at a time
|
||||
- Put the entire function call reply on one line
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||
|
||||
def test_llama_3_2_system_zero_shot(self):
|
||||
generator = PythonListCustomToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
||||
also point it out. You should only return the function call in tools call sections.
|
||||
|
||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
You SHOULD NOT include any other text in the response.
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
||||
[
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather info for places",
|
||||
"parameters": {
|
||||
"type": "dict",
|
||||
"required": ["city"],
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city to get the weather for"
|
||||
},
|
||||
"metric": {
|
||||
"type": "string",
|
||||
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||
"default": "celsius"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||
|
||||
def test_llama_3_2_provided_system_prompt(self):
|
||||
generator = PythonListCustomToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
Overriding message.
|
||||
|
||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
You SHOULD NOT include any other text in the response.
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
||||
[
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather info for places",
|
||||
"parameters": {
|
||||
"type": "dict",
|
||||
"required": ["city"],
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city to get the weather for"
|
||||
},
|
||||
"metric": {
|
||||
"type": "string",
|
||||
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||
"default": "celsius"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]"""
|
||||
)
|
||||
user_system_prompt = textwrap.dedent(
|
||||
"""
|
||||
Overriding message.
|
||||
|
||||
{{ function_description }}
|
||||
"""
|
||||
)
|
||||
example = generator.data_examples()[0]
|
||||
|
||||
pt = generator.gen(example, user_system_prompt)
|
||||
text = pt.render()
|
||||
assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
|
||||
128000
llama_stack/models/llama/llama3/tokenizer.model
Normal file
128000
llama_stack/models/llama/llama3/tokenizer.model
Normal file
File diff suppressed because it is too large
Load diff
214
llama_stack/models/llama/llama3/tokenizer.py
Normal file
214
llama_stack/models/llama/llama3/tokenizer.py
Normal file
|
|
@ -0,0 +1,214 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import os
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
AbstractSet,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import tiktoken
|
||||
from tiktoken.load import load_tiktoken_bpe
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
# The tiktoken tokenizer can handle <=400k chars without
|
||||
# pyo3_runtime.PanicException.
|
||||
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
|
||||
|
||||
# https://github.com/openai/tiktoken/issues/195
|
||||
# Here we iterate over subsequences and split if we exceed the limit
|
||||
# of max consecutive non-whitespace or whitespace characters.
|
||||
MAX_NO_WHITESPACES_CHARS = 25_000
|
||||
|
||||
|
||||
_INSTANCE = None
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""
|
||||
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
|
||||
"""
|
||||
|
||||
special_tokens: Dict[str, int]
|
||||
|
||||
num_reserved_special_tokens = 256
|
||||
|
||||
pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
global _INSTANCE
|
||||
|
||||
if _INSTANCE is None:
|
||||
_INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
|
||||
return _INSTANCE
|
||||
|
||||
def __init__(self, model_path: str):
|
||||
"""
|
||||
Initializes the Tokenizer with a Tiktoken model.
|
||||
|
||||
Args:
|
||||
model_path (str): The path to the Tiktoken model file.
|
||||
"""
|
||||
assert os.path.isfile(model_path), model_path
|
||||
|
||||
mergeable_ranks = load_tiktoken_bpe(model_path)
|
||||
num_base_tokens = len(mergeable_ranks)
|
||||
special_tokens = [
|
||||
"<|begin_of_text|>",
|
||||
"<|end_of_text|>",
|
||||
"<|reserved_special_token_0|>",
|
||||
"<|reserved_special_token_1|>",
|
||||
"<|finetune_right_pad_id|>",
|
||||
"<|step_id|>",
|
||||
"<|start_header_id|>",
|
||||
"<|end_header_id|>",
|
||||
"<|eom_id|>", # end of message
|
||||
"<|eot_id|>", # end of turn
|
||||
"<|python_tag|>",
|
||||
"<|image|>",
|
||||
]
|
||||
reserved_tokens = [
|
||||
f"<|reserved_special_token_{2 + i}|>" for i in range(self.num_reserved_special_tokens - len(special_tokens))
|
||||
]
|
||||
special_tokens = special_tokens + reserved_tokens
|
||||
|
||||
self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
|
||||
self.model = tiktoken.Encoding(
|
||||
name=Path(model_path).name,
|
||||
pat_str=self.pat_str,
|
||||
mergeable_ranks=mergeable_ranks,
|
||||
special_tokens=self.special_tokens,
|
||||
)
|
||||
|
||||
self.n_words: int = num_base_tokens + len(special_tokens)
|
||||
# BOS / EOS token IDs
|
||||
self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
|
||||
self.eos_id: int = self.special_tokens["<|end_of_text|>"]
|
||||
self.eot_id: int = self.special_tokens["<|eot_id|>"]
|
||||
self.eom_id: int = self.special_tokens["<|eom_id|>"]
|
||||
self.python_tag_id = self.special_tokens["<|python_tag|>"]
|
||||
self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"]
|
||||
self.stop_tokens = [
|
||||
self.eos_id,
|
||||
self.special_tokens["<|eom_id|>"],
|
||||
self.special_tokens["<|eot_id|>"],
|
||||
]
|
||||
|
||||
def encode(
|
||||
self,
|
||||
s: str,
|
||||
*,
|
||||
bos: bool,
|
||||
eos: bool,
|
||||
allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = (),
|
||||
) -> List[int]:
|
||||
"""
|
||||
Encodes a string into a list of token IDs.
|
||||
|
||||
Args:
|
||||
s (str): The input string to be encoded.
|
||||
bos (bool): Whether to prepend the beginning-of-sequence token.
|
||||
eos (bool): Whether to append the end-of-sequence token.
|
||||
allowed_special ("all"|set[str]): allowed special tokens in string
|
||||
disallowed_special ("all"|set[str]): special tokens that raise an error when in string
|
||||
|
||||
Returns:
|
||||
list[int]: A list of token IDs.
|
||||
|
||||
By default, setting disallowed_special=() encodes a string by ignoring
|
||||
special tokens. Specifically:
|
||||
- Setting `disallowed_special` to () will cause all text corresponding
|
||||
to special tokens to be encoded as natural text (insteading of raising
|
||||
an error).
|
||||
- Setting `allowed_special` to "all" will treat all text corresponding
|
||||
to special tokens to be encoded as special tokens.
|
||||
"""
|
||||
if allowed_special is None:
|
||||
allowed_special = set()
|
||||
assert type(s) is str
|
||||
|
||||
substrs = (
|
||||
substr
|
||||
for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
|
||||
for substr in self._split_whitespaces_or_nonwhitespaces(
|
||||
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
|
||||
)
|
||||
)
|
||||
t: List[int] = []
|
||||
for substr in substrs:
|
||||
t.extend(
|
||||
self.model.encode(
|
||||
substr,
|
||||
allowed_special=allowed_special,
|
||||
disallowed_special=disallowed_special,
|
||||
)
|
||||
)
|
||||
if bos:
|
||||
t.insert(0, self.bos_id)
|
||||
if eos:
|
||||
t.append(self.eos_id)
|
||||
return t
|
||||
|
||||
def decode(self, t: Sequence[int]) -> str:
|
||||
"""
|
||||
Decodes a list of token IDs into a string.
|
||||
|
||||
Args:
|
||||
t (List[int]): The list of token IDs to be decoded.
|
||||
|
||||
Returns:
|
||||
str: The decoded string.
|
||||
"""
|
||||
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
|
||||
return self.model.decode(cast(List[int], t))
|
||||
|
||||
@staticmethod
|
||||
def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:
|
||||
"""
|
||||
Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
|
||||
consecutive whitespaces or consecutive non-whitespaces.
|
||||
"""
|
||||
current_slice_len = 0
|
||||
current_slice_is_space = s[0].isspace() if len(s) > 0 else False
|
||||
slice_start = 0
|
||||
|
||||
for i in range(len(s)):
|
||||
is_now_space = s[i].isspace()
|
||||
|
||||
if current_slice_is_space ^ is_now_space:
|
||||
current_slice_len = 1
|
||||
current_slice_is_space = is_now_space
|
||||
else:
|
||||
current_slice_len += 1
|
||||
if current_slice_len > max_consecutive_slice_len:
|
||||
yield s[slice_start:i]
|
||||
slice_start = i
|
||||
current_slice_len = 1
|
||||
yield s[slice_start:]
|
||||
210
llama_stack/models/llama/llama3/tool_utils.py
Normal file
210
llama_stack/models/llama/llama3/tool_utils.py
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
|
||||
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
|
||||
|
||||
|
||||
def is_json(s):
|
||||
try:
|
||||
parsed = json.loads(s)
|
||||
# Return True for valid objects and not for ints, strings, etc
|
||||
return isinstance(parsed, dict)
|
||||
except json.JSONDecodeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_valid_python_list(input_string):
|
||||
"""Check if the input string is a valid Python list of function calls"""
|
||||
try:
|
||||
# Try to parse the string
|
||||
tree = ast.parse(input_string)
|
||||
|
||||
# Check if it's a single expression
|
||||
if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Expr):
|
||||
return False
|
||||
|
||||
# Check if the expression is a list
|
||||
expr = tree.body[0].value
|
||||
if not isinstance(expr, ast.List):
|
||||
return False
|
||||
|
||||
# Check if the list is empty
|
||||
if len(expr.elts) == 0:
|
||||
return False
|
||||
|
||||
# Check if all elements in the list are function calls
|
||||
for element in expr.elts:
|
||||
if not isinstance(element, ast.Call):
|
||||
return False
|
||||
|
||||
# Check if the function call has a valid name
|
||||
if not isinstance(element.func, ast.Name):
|
||||
return False
|
||||
|
||||
# Check if all arguments are keyword arguments
|
||||
if element.args or not all(isinstance(arg, ast.keyword) for arg in element.keywords):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except SyntaxError:
|
||||
# If parsing fails, it's not a valid Python expression
|
||||
return False
|
||||
|
||||
|
||||
def parse_python_list_for_function_calls(input_string):
|
||||
"""
|
||||
Parse a Python list of function calls and
|
||||
return a list of tuples containing the function name and arguments
|
||||
"""
|
||||
# Parse the string into an AST
|
||||
tree = ast.parse(input_string)
|
||||
|
||||
# Ensure the input is a list
|
||||
if not isinstance(tree.body[0], ast.Expr) or not isinstance(tree.body[0].value, ast.List):
|
||||
raise ValueError("Input must be a list of function calls")
|
||||
|
||||
result = []
|
||||
|
||||
# Iterate through each function call in the list
|
||||
for node in tree.body[0].value.elts:
|
||||
if isinstance(node, ast.Call):
|
||||
function_name = node.func.id
|
||||
function_args = {}
|
||||
|
||||
# Extract keyword arguments
|
||||
for keyword in node.keywords:
|
||||
try:
|
||||
function_args[keyword.arg] = ast.literal_eval(keyword.value)
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"Error parsing tool call argument '{keyword.arg}': {e}, full input string: '{input_string}'"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Error parsing tool call argument '{keyword.arg}', full input string: '{input_string}'"
|
||||
) from e
|
||||
|
||||
result.append((function_name, function_args))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ToolUtils:
|
||||
@staticmethod
|
||||
def is_builtin_tool_call(message_body: str) -> bool:
|
||||
match = re.search(ToolUtils.BUILTIN_TOOL_PATTERN, message_body)
|
||||
return match is not None
|
||||
|
||||
@staticmethod
|
||||
def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
|
||||
# Find the first match in the text
|
||||
match = re.search(BUILTIN_TOOL_PATTERN, message_body)
|
||||
|
||||
# Check if a match is found and return it
|
||||
if match:
|
||||
tool_name = match.group("tool_name")
|
||||
query = match.group("query")
|
||||
return tool_name, query
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def maybe_extract_custom_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
|
||||
# NOTE: Custom function too calls are still experimental
|
||||
# Sometimes, response is of the form
|
||||
# {"type": "function", "name": "function_name", "parameters": {...}
|
||||
# and some times
|
||||
# <function=function_name>(parameters)</function>
|
||||
|
||||
# Find the first match in the text
|
||||
match = re.search(CUSTOM_TOOL_CALL_PATTERN, message_body)
|
||||
if match:
|
||||
tool_name = match.group("function_name")
|
||||
query = match.group("args")
|
||||
try:
|
||||
return tool_name, json.loads(query.replace("'", '"'))
|
||||
except Exception as e:
|
||||
print("Exception while parsing json query for custom tool call", query, e)
|
||||
return None
|
||||
elif is_json(message_body):
|
||||
response = json.loads(message_body)
|
||||
if ("type" in response and response["type"] == "function") or ("name" in response):
|
||||
function_name = response["name"]
|
||||
args = response["parameters"]
|
||||
return function_name, args
|
||||
else:
|
||||
return None
|
||||
elif is_valid_python_list(message_body):
|
||||
res = parse_python_list_for_function_calls(message_body)
|
||||
# FIXME: Enable multiple tool calls
|
||||
return res[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
|
||||
if t.tool_name == BuiltinTool.brave_search:
|
||||
q = t.arguments["query"]
|
||||
return f'brave_search.call(query="{q}")'
|
||||
elif t.tool_name == BuiltinTool.wolfram_alpha:
|
||||
q = t.arguments["query"]
|
||||
return f'wolfram_alpha.call(query="{q}")'
|
||||
elif t.tool_name == BuiltinTool.photogen:
|
||||
q = t.arguments["query"]
|
||||
return f'photogen.call(query="{q}")'
|
||||
elif t.tool_name == BuiltinTool.code_interpreter:
|
||||
return t.arguments["code"]
|
||||
else:
|
||||
fname = t.tool_name
|
||||
|
||||
if tool_prompt_format == ToolPromptFormat.json:
|
||||
return json.dumps(
|
||||
{
|
||||
"type": "function",
|
||||
"name": fname,
|
||||
"parameters": t.arguments,
|
||||
}
|
||||
)
|
||||
elif tool_prompt_format == ToolPromptFormat.function_tag:
|
||||
args = json.dumps(t.arguments)
|
||||
return f"<function={fname}>{args}</function>"
|
||||
|
||||
elif tool_prompt_format == ToolPromptFormat.python_list:
|
||||
|
||||
def format_value(value: RecursiveType) -> str:
|
||||
if isinstance(value, str):
|
||||
return f'"{value}"'
|
||||
elif isinstance(value, (int, float, bool)) or value is None:
|
||||
return str(value)
|
||||
elif isinstance(value, list):
|
||||
return f"[{', '.join(format_value(v) for v in value)}]"
|
||||
elif isinstance(value, dict):
|
||||
return f"{{{', '.join(f'{k}={format_value(v)}' for k, v in value.items())}}}"
|
||||
else:
|
||||
raise ValueError(f"Unsupported type: {type(value)}")
|
||||
|
||||
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items())
|
||||
return f"[{fname}({args_str})]"
|
||||
else:
|
||||
raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")
|
||||
358
llama_stack/models/llama/llama3_1/prompt_format.md
Normal file
358
llama_stack/models/llama/llama3_1/prompt_format.md
Normal file
|
|
@ -0,0 +1,358 @@
|
|||
|
||||
|
||||
# Llama 3.1 - Prompt Formats
|
||||
## Tokens
|
||||
Here is a list of special tokens that are supported by Llama 3.1:
|
||||
- `<|begin_of_text|>`: Specifies the start of the prompt
|
||||
- `<|end_of_text|>`: Model will cease to generate more tokens. This token is generated only by the base models.
|
||||
- `<|finetune_right_pad_id|>`: This token is used for padding text sequences to the same length in a batch.
|
||||
- `<|start_header_id|>` and `<|end_header_id|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user, assistant and ipython]
|
||||
- `<|eom_id|>`: End of message. A message represents a possible stopping point for execution where the model can inform the executor that a tool call needs to be made. This is used for multi-step interactions between the model and any available tools. This token is emitted by the model when the Environment: ipython instruction is used in the system prompt, or if the model calls for a built-in tool.
|
||||
- `<|eot_id|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios:
|
||||
- at the end of a direct interaction between the model and the user
|
||||
- at the end of multiple interactions between the model and any available tools
|
||||
This token signals to the executor that the model has finished generating a response.
|
||||
- `<|python_tag|>`: Is a special tag used in the model's response to signify a tool call.
|
||||
|
||||
|
||||
|
||||
There are 4 different roles that are supported by Llama 3.1
|
||||
- `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively.
|
||||
- `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model.
|
||||
- `ipython`: A new role introduced in Llama 3.1. Semantically, this role means "tool". This role is used to mark messages with the output of a tool call when sent back to the model from the executor.
|
||||
- `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `ipython` and `user` prompts.
|
||||
|
||||
## Llama 3.1 Base Model
|
||||
|
||||
Text completion for Llama 3.1 base model uses this format.
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|>Color of sky is blue but sometimes can also be
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
red, orange, yellow, green, purple, pink, brown, gray, black, white, and even rainbow colors. The color of the sky can change due to various reasons such as time of day, weather conditions, pollution, and atmospheric phenomena.
|
||||
The color of the sky is primarily blue because of a phenomenon called
|
||||
```
|
||||
|
||||
|
||||
|
||||
Note start special tag
|
||||
|
||||
|
||||
## Llama 3.1 Instruct Model
|
||||
## User and assistant conversation
|
||||
|
||||
Here is a regular multi-turn user assistant conversation and how its formatted.
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||
|
||||
You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
Answer who are you in the form of jeopardy?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
Here's my response
|
||||
|
||||
"What is a helpful assistant?"<|eot_id|>
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
## Tool Calling Formats
|
||||
|
||||
|
||||
The three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt:
|
||||
- Brave Search: Tool call to perform web searches.
|
||||
- Wolfram Alpha: Tool call to perform complex mathematical calculations.
|
||||
- Code Interpreter: Enables the model to output python code.
|
||||
|
||||
## Builtin Tool Calling
|
||||
|
||||
|
||||
Here is an example of a conversation using brave search
|
||||
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||
|
||||
Environment: ipython
|
||||
Tools: brave_search, wolfram_alpha
|
||||
Cutting Knowledge Date: December 2023
|
||||
Today Date: 21 September 2024
|
||||
|
||||
You are a helpful assistant.
|
||||
<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
Search the web for the latest price of 1oz gold?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
<|python_tag|>brave_search.call(query="latest price of 1oz gold")<|eom_id|>
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
- Just including Environment: ipython turns on code interpreter; therefore, you don't need to specify code interpretation on the Tools: line. The model can generate python code which is interpreted by the executor, with the result provided back to the model.
|
||||
- The message body of the assistant response starts with a special tag <|python_tag|>
|
||||
- As alluded to above, in such an environment, the model can generate <|eom_id|> instead of just the standard <|eot_id|> . The latter indicates the turn is finished, while the former indicates continued multi-step reasoning. That is, the model is expecting a continuation message with the output of the tool call.
|
||||
- The model tool call response is of the form `tool.call(query="...")` wher tool is `brave_search` or `wolfram_alpha`
|
||||
|
||||
|
||||
## Builtin Code Interpreter
|
||||
|
||||
Here is an actual example of model responding with code
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||
|
||||
Environment: ipython<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
Write code to check if number is prime, use that to see if the number 7 is prime<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
<|python_tag|>def is_prime(n):
|
||||
if n <= 1
|
||||
return False
|
||||
for i in range(2, int(n**0.5) + 1):
|
||||
if n % i == 0:
|
||||
return False
|
||||
return True
|
||||
|
||||
print(is_prime(7)) # Output: True<|eom_id|>
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
- Model starts with <|python_tag|> and continues writing python code that it needs to be executed
|
||||
- No explicit mention of code_interpreter in system prompt. `Environment: ipython` implicitly enables it.
|
||||
|
||||
|
||||
## Built-in tools full interaction
|
||||
|
||||
Here is a full interaction with the built-in tools including the tool response and the final assistant response.
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||
|
||||
Environment: ipython
|
||||
Tools: brave_search, wolfram_alpha
|
||||
<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
What is the 100th decimal of pi?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
<|python_tag|>wolfram_alpha.call(query="100th decimal of pi")<|eom_id|><|start_header_id|>ipython<|end_header_id|>
|
||||
|
||||
|
||||
{
|
||||
"queryresult": {
|
||||
"success": true,
|
||||
"inputstring": "100th decimal of pi",
|
||||
"pods": [
|
||||
{
|
||||
"title": "Input interpretation",
|
||||
"subpods": [
|
||||
{
|
||||
"title": "",
|
||||
"plaintext": "100th digit | π"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"title": "Nearby digits",
|
||||
"subpods": [
|
||||
{
|
||||
"title": "",
|
||||
"plaintext": "...86208998628034825342117067982148086513282306647093..."
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"title": "Result",
|
||||
"primary": true,
|
||||
"subpods": [
|
||||
{
|
||||
"title": "",
|
||||
"plaintext": "7"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
The 100th decimal of pi is 7.<|eot_id|>
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
- Note the `<|python_tag|>` in the assistant response.
|
||||
- Role is `ipython` for the wolfram alpha response that is passed back to the model.
|
||||
- Final message from assistant has <|eot_id|> tag.
|
||||
|
||||
|
||||
|
||||
## Zero shot tool calling
|
||||
## JSON based tool calling
|
||||
|
||||
|
||||
Llama models can now output custom tool calls from a single message to allow easier tool calling.
|
||||
The following prompts provide an example of how custom tools can be called from the output of the model.
|
||||
It's important to note that the model itself does not execute the calls; it provides structured output to facilitate calling by an executor.
|
||||
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||
|
||||
Environment: ipython
|
||||
|
||||
Cutting Knowledge Date: December 2023
|
||||
Today Date: 21 September 2024
|
||||
|
||||
You are a helpful assistant.
|
||||
<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
Answer the user's question by making use of the following functions if needed.
|
||||
If none of the function can be used, please say so.
|
||||
Here is a list of functions in JSON format:
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "trending_songs",
|
||||
"description": "Returns the trending songs on a Music site",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": [
|
||||
{
|
||||
"n": {
|
||||
"type": "object",
|
||||
"description": "The number of songs to return"
|
||||
}
|
||||
},
|
||||
{
|
||||
"genre": {
|
||||
"type": "object",
|
||||
"description": "The genre of the songs to return"
|
||||
}
|
||||
}
|
||||
],
|
||||
"required": ["n"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Return function calls in JSON format.<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
Use tools to get latest trending songs<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
<|python_tag|>{
|
||||
"type": "function",
|
||||
"name": "trending_songs",
|
||||
"parameters": {
|
||||
"n": "10",
|
||||
"genre": "all"
|
||||
}
|
||||
}<|eom_id|>
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
- JSON format for providing tools needs name, description and parameters
|
||||
- Model responds with `<|python_tag|>` and `<|eom_id|>` as `Environment: ipython` was in the system prompt
|
||||
- Instructions for tools added as a user message
|
||||
- Only single tool calls are supported as of now
|
||||
|
||||
|
||||
|
||||
## Example of a user defined tool calling
|
||||
## `<function>` based tool calling
|
||||
|
||||
|
||||
Here is an example of how you could also write custom instructions for model to do zero shot tool calling.
|
||||
In this example, we define a custom tool calling format using the `<function>` tag.
|
||||
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||
|
||||
Environment: ipython
|
||||
|
||||
Cutting Knowledge Date: December 2023
|
||||
Today Date: 21 September 2024
|
||||
|
||||
You are a helpful assistant.
|
||||
<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
You have access to the following functions:
|
||||
|
||||
Use the function 'trending_songs' to 'Returns the trending songs on a Music site':
|
||||
{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}
|
||||
|
||||
Think very carefully before calling functions.
|
||||
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
||||
|
||||
<function=example_function_name>{"example_name": "example_value"}</function>
|
||||
|
||||
Reminder:
|
||||
- If looking for real time information use relevant functions before falling back to brave_search
|
||||
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
||||
- Required parameters MUST be specified
|
||||
- Only call one function at a time
|
||||
- Put the entire function call reply on one line<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
Use tools to get latest trending songs<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
<function=trending_songs>{"n": 10}</function><|eot_id|>
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
- In this case, model does NOT respond with `<|python_tag|>` and ends with `<|eot_id|>`
|
||||
- Instructions for tools added as a user message
|
||||
|
||||
|
||||
Thank You!
|
||||
|
|
@ -14,7 +14,7 @@
|
|||
import textwrap
|
||||
from typing import List
|
||||
|
||||
from llama_models.datatypes import (
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
RawMessage,
|
||||
StopReason,
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@
|
|||
import json
|
||||
import textwrap
|
||||
|
||||
from llama_models.datatypes import (
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
RawMessage,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@
|
|||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
from llama_models.datatypes import (
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
RawMediaItem,
|
||||
RawMessage,
|
||||
RawTextItem,
|
||||
|
|
|
|||
286
llama_stack/models/llama/llama3_2/text_prompt_format.md
Normal file
286
llama_stack/models/llama/llama3_2/text_prompt_format.md
Normal file
|
|
@ -0,0 +1,286 @@
|
|||
## User and assistant conversation
|
||||
|
||||
Here is a regular multi-turn user assistant conversation and how its formatted.
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||
|
||||
You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
Who are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
I'm an AI assistant, which means I'm a computer program designed to simulate conversation and answer questions to the best of my ability. I'm here to help you with any questions or tasks you may have, and I'll do my best to provide you with accurate and helpful information.
|
||||
|
||||
I don't have a personal name, but I'm often referred to as a "virtual assistant" or a "chatbot." I'm a machine learning model, which means I've been trained on a large dataset of text and can generate responses based on patterns and context.
|
||||
|
||||
I can help with a wide range of topics, from general knowledge and trivia to more specialized subjects like science, history, and technology. I can also assist with tasks like language translation, text summarization, and even generating creative content like stories or poetry.
|
||||
|
||||
So, what can I help you with today?<|eot_id|>
|
||||
```
|
||||
|
||||
|
||||
##### Notes
|
||||
This format is unchanged from Llama3.1
|
||||
|
||||
## Zero shot function calling
|
||||
|
||||
|
||||
For Llama3.2 1B and 3B instruct models, we are introducing a new format for zero shot function calling.
|
||||
This new format is designed to be more flexible and powerful than the previous format.
|
||||
All available functions can be provided in the system message. A key difference is in the format of how the assistant responds with function calls.
|
||||
It is pythonic in the form of `[func1(params_name=params_value, params_name2=params_value2...), func2(params)]` instead of the `json` or `<function>` tag that were defined in Llama3.1.
|
||||
Here is an example for the same,
|
||||
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||
|
||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
||||
also point it out. You should only return the function call in tools call sections.
|
||||
|
||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
You SHOULD NOT include any other text in the response.
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
||||
[
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather info for places",
|
||||
"parameters": {
|
||||
"type": "dict",
|
||||
"required": [
|
||||
"city"
|
||||
],
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city to get the weather for"
|
||||
},
|
||||
"metric": {
|
||||
"type": "string",
|
||||
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||
"default": "celsius"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
What is the weather in SF and Seattle?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
[get_weather(city='San Francisco', metric='celsius'), get_weather(city='Seattle', metric='celsius')]<|eot_id|>
|
||||
```
|
||||
|
||||
|
||||
##### Notes
|
||||
|
||||
- The output supports multiple tool calls natively
|
||||
- JSON format for defining the functions in the system prompt is similar to Llama3.1
|
||||
|
||||
|
||||
## Zero shot function calling with user message
|
||||
|
||||
|
||||
While the default is to provide all function calls in a system message, in Llama3.2 text models you can also provide information for all the available tools in a user message.
|
||||
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
Questions: Can you retrieve the details for the user with the ID 7890, who has black as their special request?
|
||||
Here is a list of functions in JSON format that you can invoke:
|
||||
[
|
||||
{
|
||||
"name": "get_user_info",
|
||||
"description": "Retrieve details for a specific user by their unique identifier. Note that the provided function is in Python 3 syntax.",
|
||||
"parameters": {
|
||||
"type": "dict",
|
||||
"required": [
|
||||
"user_id"
|
||||
],
|
||||
"properties": {
|
||||
"user_id": {
|
||||
"type": "integer",
|
||||
"description": "The unique identifier of the user. It is used to fetch the specific user details from the database."
|
||||
},
|
||||
"special": {
|
||||
"type": "string",
|
||||
"description": "Any special information or parameters that need to be considered while fetching user details.",
|
||||
"default": "none"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
Should you decide to return the function call(s),Put it in the format of [func1(params_name=params_value, params_name2=params_value2...), func2(params)]
|
||||
|
||||
NO other text MUST be included.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
[get_user_info(user_id=7890, special='black')]<|eot_id|>
|
||||
```
|
||||
|
||||
|
||||
##### Notes
|
||||
|
||||
- The tool call format for the model is the same whether your function calls are provided in the system or user message.
|
||||
- While builtin tool calls end with a <|eom_id|>, notice the <|eot_id|> for zero shot tool calls.
|
||||
|
||||
|
||||
## Code Interpreter
|
||||
|
||||
|
||||
Code Interpreter continues to work in 3.2 text models similar to Llama 3.1 model family.
|
||||
Here is an example,
|
||||
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||
|
||||
Environment: ipython
|
||||
|
||||
Cutting Knowledge Date: December 2023
|
||||
Today Date: 24 September 2024
|
||||
<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
Write code to check if number is prime. Use it to verify if number 7 is prime<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
<|python_tag|>def is_prime(n):
|
||||
if n <= 1:
|
||||
return False
|
||||
if n == 2:
|
||||
return True
|
||||
if n % 2 == 0:
|
||||
return False
|
||||
max_divisor = int(n**0.5) + 1
|
||||
for d in range(3, max_divisor, 2):
|
||||
if n % d == 0:
|
||||
return False
|
||||
return True
|
||||
|
||||
print(is_prime(7)) # Output: True<|eom_id|>
|
||||
```
|
||||
|
||||
|
||||
##### Notes
|
||||
|
||||
- Note `Environment: ipython` in the system prompt.
|
||||
- Note that the response starts with `<|python_tag|>` and ends with `<|eom_id|>`
|
||||
|
||||
|
||||
## Zero shot function calling E2E format
|
||||
|
||||
|
||||
Here is an example of the e2e cycle of tool calls with the model in a muti-step way.
|
||||
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||
|
||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
||||
also point it out. You should only return the function call in tools call sections.
|
||||
|
||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
You SHOULD NOT include any other text in the response.
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
||||
[
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather info for places",
|
||||
"parameters": {
|
||||
"type": "dict",
|
||||
"required": [
|
||||
"city"
|
||||
],
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city to get the weather for"
|
||||
},
|
||||
"metric": {
|
||||
"type": "string",
|
||||
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||
"default": "celsius"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
What is the weather in SF?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
<|python_tag|>[get_weather(city="San Francisco", metric="celsius")]<|eot_id|><|start_header_id|>ipython<|end_header_id|>
|
||||
|
||||
"25 C"<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
The weather in San Francisco is 25 C.<|eot_id|>
|
||||
```
|
||||
|
||||
|
||||
##### Notes
|
||||
|
||||
- The output of the function call is provided back to the model as a tool response ( in json format ).
|
||||
- Notice `<|start_header_id|>ipython<|end_header_id|>` as the header message preceding the tool response.
|
||||
- The model finally summarizes the information from the tool response and returns the result to the user.
|
||||
|
||||
|
||||
## Prompt format for base models
|
||||
|
||||
|
||||
For base models (Llama3.2-1B and Llama3.2-3B), the prompt format for a simple completion is as follows
|
||||
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|>The color of the sky is blue but sometimes it can also be
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
gray or even purple. The color of the sky can change depending on the time of day, the weather, and the amount of pollution in the air. The color of the sky can also be affected by the presence of dust, smoke, and other particles in the air.
|
||||
|
||||
## Step 1: Identify the factors that
|
||||
```
|
||||
|
||||
|
||||
##### Notes
|
||||
Same as Llama3.1
|
||||
|
||||
Thank You!
|
||||
141
llama_stack/models/llama/llama3_2/vision_prompt_format.md
Normal file
141
llama_stack/models/llama/llama3_2/vision_prompt_format.md
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
## User and assistant conversation
|
||||
|
||||
Here is a regular multi-turn user assistant conversation and how its formatted.
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||
|
||||
You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
Who are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
I'm a helpful assistant, here to provide information, answer questions, and assist with tasks to the best of my abilities. I'm a large language model, which means I can understand and respond to natural language inputs, and I'm constantly learning and improving to provide more accurate and helpful responses.
|
||||
|
||||
I can help with a wide range of topics, from general knowledge and trivia to more specific areas like science, history, technology, and more. I can also assist with tasks like language translation, text summarization, and even generating creative content like stories or dialogues.
|
||||
|
||||
What can I help you with today?<|eot_id|>
|
||||
```
|
||||
|
||||
|
||||
##### Notes
|
||||
This format is unchanged from Llama3.1
|
||||
|
||||
## User and assistant conversation with Images
|
||||
|
||||
This example shows how to pass and image to the model as part of the messages.
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
<|image|>Describe this image in two sentences<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
The image depicts a small dog standing on a skateboard, with its front paws firmly planted on the board and its back paws slightly raised. The dog's fur is predominantly brown and white, with a distinctive black stripe running down its back, and it is wearing a black collar around its neck.<|eot_id|>
|
||||
```
|
||||
|
||||
|
||||
##### Notes
|
||||
|
||||
- The `<|image|>` tag is used to indicate presence of the image
|
||||
- The model isn't an early fusion model so doesn't actually translate an image into several tokens. Instead the cross-attention layers take input "on the side" from a vision encoder
|
||||

|
||||
- Its important to postion the <|image|> tag appropriately in the prompt. Image will only attend to the subsequent text tokens
|
||||
- The <|image|> tag is part of the user message body, implying that it should only come after the header `<|start_header_id|>{role}<|end_header_id|>` in the message body
|
||||
- We recommend using a single image in one prompt
|
||||
|
||||
|
||||
## Builtin and Zero Shot Tool Calling
|
||||
|
||||
|
||||
Llama3.2 vision models follow the same tool calling format as Llama3.1 models when inputs are text only.
|
||||
Use `Environment: ipython` to enable tools.
|
||||
Add `Tools: {{tool_name1}},{{tool_name2}}` for each of the builtin tools.
|
||||
The same builtin tools as Llama3.1 are available,
|
||||
- code_interpreter (for executing python code)
|
||||
- brave_search (to search the web)
|
||||
- wolfram_alpha (for querying wolfram alpha for mathematical questions)
|
||||
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||
|
||||
Environment: ipython
|
||||
Tools: brave_search, wolfram_alpha
|
||||
Cutting Knowledge Date: December 2023
|
||||
Today Date: 23 September 2024
|
||||
|
||||
You are a helpful assistant.
|
||||
<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
Search the web for the latest price of 1oz gold?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
<|python_tag|>brave_search.call(query="latest price of 1oz gold")<|eom_id|>
|
||||
```
|
||||
|
||||
|
||||
##### Notes
|
||||
|
||||
- Note the `<|python_tag|>` before `brave_search` function call.
|
||||
- The `<|eom_id|>` tag is used to indicate the end of the message.
|
||||
- Similar to Llama3.1, code_interpreter is not explicitly mentioned but is enabled via `Environment: ipython`.
|
||||
- Tool Calling does NOT work with images in the prompt as of now.
|
||||
|
||||
|
||||
## Prompt format for base models
|
||||
|
||||
|
||||
For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), the prompt format for a simple completion is as follows
|
||||
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|>The color of the sky is blue but sometimes it can also be
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
red, orange, pink, purple, and even black. The color of the sky is determined by the amount of sunlight that is scattered by the atmosphere and the amount of dust and water vapor present in the atmosphere. During sunrise and sunset, the sky can take on a range of colors due to the scattering of light by
|
||||
```
|
||||
|
||||
|
||||
##### Notes
|
||||
- Same as Llama3.1
|
||||
|
||||
## Prompt format for base models with Image
|
||||
|
||||
|
||||
For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), here is an example of how the text completion format looks with an image,
|
||||
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|><|image|>If I had to write a haiku for this one
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
, it would be: A skateboarder's delight, a puppy on a board, a furry little thrill-seeker. This puppy is a true skateboarding enthusiast, always eager to hit the streets and show off his skills. He's a master of the board, gliding effortlessly across the pavement with grace and style.
|
||||
```
|
||||
|
||||
|
||||
##### Notes
|
||||
- Note the placement of the special tags <|begin_of_text|> and <|image|>
|
||||
|
||||
Thank You!
|
||||
|
|
@ -14,7 +14,7 @@
|
|||
import textwrap
|
||||
from typing import List
|
||||
|
||||
from llama_models.datatypes import (
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
RawMessage,
|
||||
StopReason,
|
||||
|
|
|
|||
|
|
@ -16,7 +16,9 @@ import textwrap
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from llama_models.datatypes import (
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
RawContent,
|
||||
RawMediaItem,
|
||||
RawMessage,
|
||||
|
|
@ -25,7 +27,6 @@ from llama_models.datatypes import (
|
|||
ToolCall,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .llama3.interface import LLama31Interface
|
||||
from .llama3.template_data import (
|
||||
|
|
|
|||
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]):
|
||||
from .agents import MetaReferenceAgentsImpl
|
||||
|
||||
impl = MetaReferenceAgentsImpl(
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -12,6 +12,7 @@ import uuid
|
|||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
Agent,
|
||||
AgentConfig,
|
||||
AgentCreateResponse,
|
||||
Agents,
|
||||
|
|
@ -21,12 +22,15 @@ from llama_stack.apis.agents import (
|
|||
AgentTurnCreateRequest,
|
||||
AgentTurnResumeRequest,
|
||||
Document,
|
||||
ListAgentSessionsResponse,
|
||||
ListAgentsResponse,
|
||||
Session,
|
||||
Turn,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
ToolConfig,
|
||||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
|
@ -83,7 +87,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
async def get_agent(self, agent_id: str) -> ChatAgent:
|
||||
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
|
||||
agent_config = await self.persistence_store.get(
|
||||
key=f"agent:{agent_id}",
|
||||
)
|
||||
|
|
@ -119,7 +123,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id: str,
|
||||
session_name: str,
|
||||
) -> AgentSessionCreateResponse:
|
||||
agent = await self.get_agent(agent_id)
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
session_id = await agent.create_session(session_name)
|
||||
return AgentSessionCreateResponse(
|
||||
|
|
@ -140,7 +144,6 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
documents: Optional[List[Document]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
allow_turn_resume: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnCreateRequest(
|
||||
agent_id=agent_id,
|
||||
|
|
@ -150,7 +153,6 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
toolgroups=toolgroups,
|
||||
documents=documents,
|
||||
tool_config=tool_config,
|
||||
allow_turn_resume=allow_turn_resume,
|
||||
)
|
||||
if stream:
|
||||
return self._create_agent_turn_streaming(request)
|
||||
|
|
@ -161,7 +163,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self,
|
||||
request: AgentTurnCreateRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self.get_agent(request.agent_id)
|
||||
agent = await self._get_agent_impl(request.agent_id)
|
||||
async for event in agent.create_and_execute_turn(request):
|
||||
yield event
|
||||
|
||||
|
|
@ -170,7 +172,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: List[ToolResponseMessage],
|
||||
tool_responses: List[ToolResponse],
|
||||
stream: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnResumeRequest(
|
||||
|
|
@ -189,22 +191,18 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self,
|
||||
request: AgentTurnResumeRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self.get_agent(request.agent_id)
|
||||
agent = await self._get_agent_impl(request.agent_id)
|
||||
async for event in agent.resume_turn(request):
|
||||
yield event
|
||||
|
||||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
||||
turn = json.loads(turn)
|
||||
turn = Turn(**turn)
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
turn = await agent.storage.get_session_turn(session_id, turn_id)
|
||||
return turn
|
||||
|
||||
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
|
||||
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
||||
turn = json.loads(turn)
|
||||
turn = Turn(**turn)
|
||||
steps = turn.steps
|
||||
for step in steps:
|
||||
turn = await self.get_agents_turn(agent_id, session_id, turn_id)
|
||||
for step in turn.steps:
|
||||
if step.step_id == step_id:
|
||||
return AgentStepResponse(step=step)
|
||||
raise ValueError(f"Provided step_id {step_id} could not be found")
|
||||
|
|
@ -215,20 +213,18 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
session_id: str,
|
||||
turn_ids: Optional[List[str]] = None,
|
||||
) -> Session:
|
||||
session = await self.persistence_store.get(f"session:{agent_id}:{session_id}")
|
||||
session = Session(**json.loads(session), turns=[])
|
||||
turns = []
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
session_info = await agent.storage.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
turns = await agent.storage.get_session_turns(session_id)
|
||||
if turn_ids:
|
||||
for turn_id in turn_ids:
|
||||
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
||||
turn = json.loads(turn)
|
||||
turn = Turn(**turn)
|
||||
turns.append(turn)
|
||||
turns = [turn for turn in turns if turn.turn_id in turn_ids]
|
||||
return Session(
|
||||
session_name=session.session_name,
|
||||
session_name=session_info.session_name,
|
||||
session_id=session_id,
|
||||
turns=turns if turns else [],
|
||||
started_at=session.started_at,
|
||||
turns=turns,
|
||||
started_at=session_info.started_at,
|
||||
)
|
||||
|
||||
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
|
||||
|
|
@ -239,3 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_agents(self) -> ListAgentsResponse:
|
||||
pass
|
||||
|
||||
async def get_agent(self, agent_id: str) -> Agent:
|
||||
pass
|
||||
|
||||
async def list_agent_sessions(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> ListAgentSessionsResponse:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -21,6 +21,7 @@ log = logging.getLogger(__name__)
|
|||
class AgentSessionInfo(BaseModel):
|
||||
session_id: str
|
||||
session_name: str
|
||||
# TODO: is this used anywhere?
|
||||
vector_db_id: Optional[str] = None
|
||||
started_at: datetime
|
||||
|
||||
|
|
@ -35,7 +36,7 @@ class AgentPersistence:
|
|||
session_info = AgentSessionInfo(
|
||||
session_id=session_id,
|
||||
session_name=name,
|
||||
started_at=datetime.now(),
|
||||
started_at=datetime.now(timezone.utc),
|
||||
)
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
|
|
@ -85,6 +86,14 @@ class AgentPersistence:
|
|||
turns.sort(key=lambda x: (x.completed_at or datetime.min))
|
||||
return turns
|
||||
|
||||
async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
|
||||
value = await self.kvstore.get(
|
||||
key=f"session:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
if not value:
|
||||
return None
|
||||
return Turn(**json.loads(value))
|
||||
|
||||
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
|
||||
await self.kvstore.set(
|
||||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||
|
|
@ -96,3 +105,15 @@ class AgentPersistence:
|
|||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
return ToolExecutionStep(**json.loads(value)) if value else None
|
||||
|
||||
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
|
||||
await self.kvstore.set(
|
||||
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||
value=str(num_infer_iters),
|
||||
)
|
||||
|
||||
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
|
||||
value = await self.kvstore.get(
|
||||
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
return int(value) if value else None
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue