Fix precommit check after moving to ruff (#927)

The lint check on the main branch is failing. This fixes the lint check after we
moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We
need to move to a `ruff.toml` file as well as fix and ignore some
additional checks.

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
Author: Yuan Tang, 2025-02-02 09:46:45 -05:00 (committed via GitHub)
parent 4773092dd1
commit 34ab7a3b6c
217 changed files with 981 additions and 2681 deletions


@@ -1,7 +1,8 @@
-[flake8]
 # Suggested config from pytorch that we can adapt
-select = B,C,E,F,N,P,T4,W,B9,TOR0,TOR1,TOR2
-max-line-length = 120
+lint.select = ["B", "C", "E" , "F" , "N", "W", "B9"]
+line-length = 120
 # C408 ignored because we like the dict keyword argument syntax
 # E501 is not flexible enough, we're using B950 instead
 # N812 ignored because import torch.nn.functional as F is PyTorch convention
@@ -9,23 +10,28 @@ max-line-length = 120
 # E731 allow usage of assigning lambda expressions
 # E701 let black auto-format statements on one line
 # E704 let black auto-format statements on one line
-ignore =
-    E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731,E701,E704
+lint.ignore = [
+    "E203", "E305", "E402", "E501", "E721", "E741", "F405", "F821", "F841",
+    "C408", "E302", "W291", "E303", "N812", "N817", "E731", "E701",
+    # These are the additional ones we started ignoring after moving to ruff. We should look into each one of them later.
+    "C901", "C405", "C414", "N803", "N999", "C403", "C416", "B028", "C419", "C401", "B023",
     # shebang has extra meaning in fbcode lints, so I think it's not worth trying
     # to line this up with executable bit
-    EXE001,
+    "EXE001",
    # random naming hints don't need
-    N802,
+    "N802",
    # these ignores are from flake8-bugbear; please fix!
-    B007,B008,B950
-optional-ascii-coding = True
-exclude =
-    ./.git,
-    ./docs/*,
-    ./build,
-    ./scripts,
-    ./venv,
-    *.pyi,
-    .pre-commit-config.yaml,
-    *.md,
-    .flake8
+    "B007", "B008"
+]
+
+exclude = [
+    "./.git",
+    "./docs/*",
+    "./build",
+    "./scripts",
+    "./venv",
+    "*.pyi",
+    ".pre-commit-config.yaml",
+    "*.md",
+    ".flake8"
+]
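For anyone reproducing the failing check locally, here is a minimal sketch of how a config like this is typically exercised (assumptions: ruff and pre-commit are installed and the commands are run from the repository root; the exact hook wiring lives in `.pre-commit-config.yaml`, which is not part of this diff):

```bash
# Lint with the rules selected/ignored above; ruff discovers ruff.toml automatically
ruff check .

# Apply ruff's formatter
ruff format .

# Or run every configured hook, the same way the CI lint job does
pre-commit run --all-files
```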


@@ -77,7 +77,7 @@ agent_config = AgentConfig(
     instructions="You are a helpful assistant",
     # Enable both RAG and tool usage
     toolgroups=[
-        {"name": "builtin::rag", "args": {"vector_db_ids": ["my_docs"]}}.
+        {"name": "builtin::rag", "args": {"vector_db_ids": ["my_docs"]}},
         "builtin::code_interpreter",
     ],
     # Configure safety
@@ -86,13 +86,9 @@ agent_config = AgentConfig(
     # Control the inference loop
     max_infer_iters=5,
     sampling_params={
-        "strategy": {
-            "type": "top_p",
-            "temperature": 0.7,
-            "top_p": 0.95
-        },
-        "max_tokens": 2048
-    }
+        "strategy": {"type": "top_p", "temperature": 0.7, "top_p": 0.95},
+        "max_tokens": 2048,
+    },
 )
 
 agent = Agent(client, agent_config)
@@ -101,11 +97,13 @@ session_id = agent.create_session("monitored_session")
 # Stream the agent's execution steps
 response = agent.create_turn(
     messages=[{"role": "user", "content": "Analyze this code and run it"}],
-    attachments=[{
-        "content": "https://raw.githubusercontent.com/example/code.py",
-        "mime_type": "text/plain"
-    }],
-    session_id=session_id
+    attachments=[
+        {
+            "content": "https://raw.githubusercontent.com/example/code.py",
+            "mime_type": "text/plain",
+        }
+    ],
+    session_id=session_id,
 )
 
 # Monitor each step of execution


@ -15,6 +15,7 @@ This first example walks you through how to evaluate a model candidate served by
```python ```python
import datasets import datasets
ds = datasets.load_dataset(path="llamastack/mmmu", name="Agriculture", split="dev") ds = datasets.load_dataset(path="llamastack/mmmu", name="Agriculture", split="dev")
ds = ds.select_columns(["chat_completion_input", "input_query", "expected_answer"]) ds = ds.select_columns(["chat_completion_input", "input_query", "expected_answer"])
eval_rows = ds.to_pandas().to_dict(orient="records") eval_rows = ds.to_pandas().to_dict(orient="records")
@ -43,7 +44,7 @@ system_message = {
client.eval_tasks.register( client.eval_tasks.register(
eval_task_id="meta-reference::mmmu", eval_task_id="meta-reference::mmmu",
dataset_id=f"mmmu-{subset}-{split}", dataset_id=f"mmmu-{subset}-{split}",
scoring_functions=["basic::regex_parser_multiple_choice_answer"] scoring_functions=["basic::regex_parser_multiple_choice_answer"],
) )
response = client.eval.evaluate_rows( response = client.eval.evaluate_rows(
@ -62,9 +63,9 @@ response = client.eval.evaluate_rows(
"max_tokens": 4096, "max_tokens": 4096,
"repeat_penalty": 1.0, "repeat_penalty": 1.0,
}, },
"system_message": system_message "system_message": system_message,
} },
} },
) )
``` ```
@ -88,7 +89,7 @@ _ = client.datasets.register(
"input_query": {"type": "string"}, "input_query": {"type": "string"},
"expected_answer": {"type": "string"}, "expected_answer": {"type": "string"},
"chat_completion_input": {"type": "chat_completion_input"}, "chat_completion_input": {"type": "chat_completion_input"},
} },
) )
eval_rows = client.datasetio.get_rows_paginated( eval_rows = client.datasetio.get_rows_paginated(
@ -101,7 +102,7 @@ eval_rows = client.datasetio.get_rows_paginated(
client.eval_tasks.register( client.eval_tasks.register(
eval_task_id="meta-reference::simpleqa", eval_task_id="meta-reference::simpleqa",
dataset_id=simpleqa_dataset_id, dataset_id=simpleqa_dataset_id,
scoring_functions=["llm-as-judge::405b-simpleqa"] scoring_functions=["llm-as-judge::405b-simpleqa"],
) )
response = client.eval.evaluate_rows( response = client.eval.evaluate_rows(
@ -120,8 +121,8 @@ response = client.eval.evaluate_rows(
"max_tokens": 4096, "max_tokens": 4096,
"repeat_penalty": 1.0, "repeat_penalty": 1.0,
}, },
} },
} },
) )
``` ```
@ -144,14 +145,14 @@ agent_config = {
{ {
"type": "brave_search", "type": "brave_search",
"engine": "tavily", "engine": "tavily",
"api_key": userdata.get("TAVILY_SEARCH_API_KEY") "api_key": userdata.get("TAVILY_SEARCH_API_KEY"),
} }
], ],
"tool_choice": "auto", "tool_choice": "auto",
"tool_prompt_format": "json", "tool_prompt_format": "json",
"input_shields": [], "input_shields": [],
"output_shields": [], "output_shields": [],
"enable_session_persistence": False "enable_session_persistence": False,
} }
response = client.eval.evaluate_rows( response = client.eval.evaluate_rows(
@ -163,7 +164,7 @@ response = client.eval.evaluate_rows(
"eval_candidate": { "eval_candidate": {
"type": "agent", "type": "agent",
"config": agent_config, "config": agent_config,
} },
} },
) )
``` ```


@@ -13,7 +13,7 @@ Here's how to set up basic evaluation:
 response = client.eval_tasks.register(
     eval_task_id="my_eval",
     dataset_id="my_dataset",
-    scoring_functions=["accuracy", "relevance"]
+    scoring_functions=["accuracy", "relevance"],
 )
 
 # Run evaluation
@@ -21,16 +21,10 @@ job = client.eval.run_eval(
     task_id="my_eval",
     task_config={
         "type": "app",
-        "eval_candidate": {
-            "type": "agent",
-            "config": agent_config
-        }
-    }
+        "eval_candidate": {"type": "agent", "config": agent_config},
+    },
 )
 
 # Get results
-result = client.eval.job_result(
-    task_id="my_eval",
-    job_id=job.job_id
-)
+result = client.eval.job_result(task_id="my_eval", job_id=job.job_id)
 ```


@@ -34,15 +34,16 @@ chunks = [
     {
         "document_id": "doc1",
         "content": "Your document text here",
-        "mime_type": "text/plain"
+        "mime_type": "text/plain",
     },
-    ...
+    ...,
 ]
 client.vector_io.insert(vector_db_id, chunks)
 
 # You can then query for these chunks
-chunks_response = client.vector_io.query(vector_db_id, query="What do you know about...")
+chunks_response = client.vector_io.query(
+    vector_db_id, query="What do you know about..."
+)
 ```
 
 ### Using the RAG Tool
@@ -81,7 +82,6 @@ results = client.tool_runtime.rag_tool.query(
 One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:
 
 ```python
-
 # Configure agent with memory
 agent_config = AgentConfig(
     model="Llama3.2-3B-Instruct",
@@ -91,9 +91,9 @@ agent_config = AgentConfig(
             "name": "builtin::rag",
             "args": {
                 "vector_db_ids": [vector_db_id],
-            }
+            },
         }
-    ]
+    ],
 )
 
 agent = Agent(client, agent_config)
@@ -101,25 +101,21 @@ session_id = agent.create_session("rag_session")
 # Initial document ingestion
 response = agent.create_turn(
-    messages=[{
-        "role": "user",
-        "content": "I am providing some documents for reference."
-    }],
+    messages=[
+        {"role": "user", "content": "I am providing some documents for reference."}
+    ],
     documents=[
         dict(
             content="https://raw.githubusercontent.com/example/doc.rst",
-            mime_type="text/plain"
+            mime_type="text/plain",
         )
     ],
-    session_id=session_id
+    session_id=session_id,
 )
 
 # Query with RAG
 response = agent.create_turn(
-    messages=[{
-        "role": "user",
-        "content": "What are the key topics in the documents?"
-    }],
-    session_id=session_id
+    messages=[{"role": "user", "content": "What are the key topics in the documents?"}],
+    session_id=session_id,
 )
 ```


@@ -5,15 +5,11 @@ Safety is a critical component of any AI application. Llama Stack provides a Shi
 ```python
 # Register a safety shield
 shield_id = "content_safety"
-client.shields.register(
-    shield_id=shield_id,
-    provider_shield_id="llama-guard-basic"
-)
+client.shields.register(shield_id=shield_id, provider_shield_id="llama-guard-basic")
 
 # Run content through shield
 response = client.safety.run_shield(
-    shield_id=shield_id,
-    messages=[{"role": "user", "content": "User message here"}]
+    shield_id=shield_id, messages=[{"role": "user", "content": "User message here"}]
 )
 
 if response.violation:


@@ -8,24 +8,16 @@ The telemetry system supports three main types of events:
 - **Unstructured Log Events**: Free-form log messages with severity levels
 ```python
 unstructured_log_event = UnstructuredLogEvent(
-    message="This is a log message",
-    severity=LogSeverity.INFO
+    message="This is a log message", severity=LogSeverity.INFO
 )
 ```
 - **Metric Events**: Numerical measurements with units
 ```python
-metric_event = MetricEvent(
-    metric="my_metric",
-    value=10,
-    unit="count"
-)
+metric_event = MetricEvent(metric="my_metric", value=10, unit="count")
 ```
 - **Structured Log Events**: System events like span start/end. Extensible to add more structured log types.
 ```python
-structured_log_event = SpanStartPayload(
-    name="my_span",
-    parent_span_id="parent_span_id"
-)
+structured_log_event = SpanStartPayload(name="my_span", parent_span_id="parent_span_id")
 ```
 
 ### Spans and Traces


@ -35,7 +35,7 @@ Example client SDK call to register a "websearch" toolgroup that is provided by
client.toolgroups.register( client.toolgroups.register(
toolgroup_id="builtin::websearch", toolgroup_id="builtin::websearch",
provider_id="brave-search", provider_id="brave-search",
args={"max_results": 5} args={"max_results": 5},
) )
``` ```
@ -50,8 +50,7 @@ The Code Interpreter allows execution of Python code within a controlled environ
```python ```python
# Register Code Interpreter tool group # Register Code Interpreter tool group
client.toolgroups.register( client.toolgroups.register(
toolgroup_id="builtin::code_interpreter", toolgroup_id="builtin::code_interpreter", provider_id="code_interpreter"
provider_id="code_interpreter"
) )
``` ```
@ -68,16 +67,14 @@ The WolframAlpha tool provides access to computational knowledge through the Wol
```python ```python
# Register WolframAlpha tool group # Register WolframAlpha tool group
client.toolgroups.register( client.toolgroups.register(
toolgroup_id="builtin::wolfram_alpha", toolgroup_id="builtin::wolfram_alpha", provider_id="wolfram-alpha"
provider_id="wolfram-alpha"
) )
``` ```
Example usage: Example usage:
```python ```python
result = client.tool_runtime.invoke_tool( result = client.tool_runtime.invoke_tool(
tool_name="wolfram_alpha", tool_name="wolfram_alpha", args={"query": "solve x^2 + 2x + 1 = 0"}
args={"query": "solve x^2 + 2x + 1 = 0"}
) )
``` ```
@ -90,10 +87,7 @@ The Memory tool enables retrieval of context from various types of memory banks
client.toolgroups.register( client.toolgroups.register(
toolgroup_id="builtin::memory", toolgroup_id="builtin::memory",
provider_id="memory", provider_id="memory",
args={ args={"max_chunks": 5, "max_tokens_in_context": 4096},
"max_chunks": 5,
"max_tokens_in_context": 4096
}
) )
``` ```
@ -136,9 +130,7 @@ config = AgentConfig(
toolgroups=[ toolgroups=[
"builtin::websearch", "builtin::websearch",
], ],
client_tools=[ client_tools=[ToolDef(name="client_tool", description="Client provided tool")],
ToolDef(name="client_tool", description="Client provided tool")
]
) )
``` ```
@ -167,9 +159,9 @@ Example tool definition:
"name": "query", "name": "query",
"parameter_type": "string", "parameter_type": "string",
"description": "The query to search for", "description": "The query to search for",
"required": True "required": True,
} }
] ],
} }
``` ```
@ -179,8 +171,7 @@ Tools can be invoked using the `invoke_tool` method:
```python ```python
result = client.tool_runtime.invoke_tool( result = client.tool_runtime.invoke_tool(
tool_name="web_search", tool_name="web_search", kwargs={"query": "What is the capital of France?"}
kwargs={"query": "What is the capital of France?"}
) )
``` ```


@@ -1,9 +1,9 @@
 # Using Llama Stack as a Library
 
 If you are planning to use an external service for Inference (even Ollama or TGI counts as external), it is often easier to use Llama Stack as a library. This avoids the overhead of setting up a server.
-```python
+```bash
 # setup
-pip install llama-stack
+uv pip install llama-stack
 llama stack build --template together --image-type venv
 ```
 
@@ -13,7 +13,7 @@ from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
 client = LlamaStackAsLibraryClient(
     "ollama",
     # provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
-    provider_data = {"tavily_search_api_key": os.environ['TAVILY_SEARCH_API_KEY']}
+    provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
 )
 await client.initialize()
 ```


@ -96,18 +96,26 @@ Here is a simple example to perform chat completions using the SDK.
```python ```python
import os import os
def create_http_client(): def create_http_client():
from llama_stack_client import LlamaStackClient from llama_stack_client import LlamaStackClient
return LlamaStackClient(base_url=f"http://localhost:{os.environ['LLAMA_STACK_PORT']}")
return LlamaStackClient(
base_url=f"http://localhost:{os.environ['LLAMA_STACK_PORT']}"
)
def create_library_client(template="ollama"): def create_library_client(template="ollama"):
from llama_stack import LlamaStackAsLibraryClient from llama_stack import LlamaStackAsLibraryClient
client = LlamaStackAsLibraryClient(template) client = LlamaStackAsLibraryClient(template)
client.initialize() client.initialize()
return client return client
client = create_library_client() # or create_http_client() depending on the environment you picked client = (
create_library_client()
) # or create_http_client() depending on the environment you picked
# List available models # List available models
models = client.models.list() models = client.models.list()
@ -120,8 +128,8 @@ response = client.inference.chat_completion(
model_id=os.environ["INFERENCE_MODEL"], model_id=os.environ["INFERENCE_MODEL"],
messages=[ messages=[
{"role": "system", "content": "You are a helpful assistant."}, {"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Write a haiku about coding"} {"role": "user", "content": "Write a haiku about coding"},
] ],
) )
print(response.completion_message.content) print(response.completion_message.content)
``` ```
@ -139,7 +147,9 @@ from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types import Document from llama_stack_client.types import Document
client = create_library_client() # or create_http_client() depending on the environment you picked client = (
create_library_client()
) # or create_http_client() depending on the environment you picked
# Documents to be used for RAG # Documents to be used for RAG
urls = ["chat.rst", "llama3.rst", "datasets.rst", "lora_finetune.rst"] urls = ["chat.rst", "llama3.rst", "datasets.rst", "lora_finetune.rst"]
@ -174,12 +184,12 @@ agent_config = AgentConfig(
instructions="You are a helpful assistant", instructions="You are a helpful assistant",
enable_session_persistence=False, enable_session_persistence=False,
# Define tools available to the agent # Define tools available to the agent
toolgroups = [ toolgroups=[
{ {
"name": "builtin::rag", "name": "builtin::rag",
"args" : { "args": {
"vector_db_ids": [vector_db_id], "vector_db_ids": [vector_db_id],
} },
} }
], ],
) )
@ -193,7 +203,7 @@ user_prompts = [
# Run the agent loop by calling the `create_turn` method # Run the agent loop by calling the `create_turn` method
for prompt in user_prompts: for prompt in user_prompts:
cprint(f'User> {prompt}', 'green') cprint(f"User> {prompt}", "green")
response = rag_agent.create_turn( response = rag_agent.create_turn(
messages=[{"role": "user", "content": prompt}], messages=[{"role": "user", "content": prompt}],
session_id=session_id, session_id=session_id,


@ -51,6 +51,7 @@ This first example walks you through how to evaluate a model candidate served by
```python ```python
import datasets import datasets
ds = datasets.load_dataset(path="llamastack/mmmu", name="Agriculture", split="dev") ds = datasets.load_dataset(path="llamastack/mmmu", name="Agriculture", split="dev")
ds = ds.select_columns(["chat_completion_input", "input_query", "expected_answer"]) ds = ds.select_columns(["chat_completion_input", "input_query", "expected_answer"])
eval_rows = ds.to_pandas().to_dict(orient="records") eval_rows = ds.to_pandas().to_dict(orient="records")
@ -79,7 +80,7 @@ system_message = {
client.eval_tasks.register( client.eval_tasks.register(
eval_task_id="meta-reference::mmmu", eval_task_id="meta-reference::mmmu",
dataset_id=f"mmmu-{subset}-{split}", dataset_id=f"mmmu-{subset}-{split}",
scoring_functions=["basic::regex_parser_multiple_choice_answer"] scoring_functions=["basic::regex_parser_multiple_choice_answer"],
) )
response = client.eval.evaluate_rows( response = client.eval.evaluate_rows(
@ -98,9 +99,9 @@ response = client.eval.evaluate_rows(
"max_tokens": 4096, "max_tokens": 4096,
"repeat_penalty": 1.0, "repeat_penalty": 1.0,
}, },
"system_message": system_message "system_message": system_message,
} },
} },
) )
``` ```
@ -124,7 +125,7 @@ _ = client.datasets.register(
"input_query": {"type": "string"}, "input_query": {"type": "string"},
"expected_answer": {"type": "string"}, "expected_answer": {"type": "string"},
"chat_completion_input": {"type": "chat_completion_input"}, "chat_completion_input": {"type": "chat_completion_input"},
} },
) )
eval_rows = client.datasetio.get_rows_paginated( eval_rows = client.datasetio.get_rows_paginated(
@ -137,7 +138,7 @@ eval_rows = client.datasetio.get_rows_paginated(
client.eval_tasks.register( client.eval_tasks.register(
eval_task_id="meta-reference::simpleqa", eval_task_id="meta-reference::simpleqa",
dataset_id=simpleqa_dataset_id, dataset_id=simpleqa_dataset_id,
scoring_functions=["llm-as-judge::405b-simpleqa"] scoring_functions=["llm-as-judge::405b-simpleqa"],
) )
response = client.eval.evaluate_rows( response = client.eval.evaluate_rows(
@ -156,8 +157,8 @@ response = client.eval.evaluate_rows(
"max_tokens": 4096, "max_tokens": 4096,
"repeat_penalty": 1.0, "repeat_penalty": 1.0,
}, },
} },
} },
) )
``` ```
@ -180,14 +181,14 @@ agent_config = {
{ {
"type": "brave_search", "type": "brave_search",
"engine": "tavily", "engine": "tavily",
"api_key": userdata.get("TAVILY_SEARCH_API_KEY") "api_key": userdata.get("TAVILY_SEARCH_API_KEY"),
} }
], ],
"tool_choice": "auto", "tool_choice": "auto",
"tool_prompt_format": "json", "tool_prompt_format": "json",
"input_shields": [], "input_shields": [],
"output_shields": [], "output_shields": [],
"enable_session_persistence": False "enable_session_persistence": False,
} }
response = client.eval.evaluate_rows( response = client.eval.evaluate_rows(
@ -199,8 +200,8 @@ response = client.eval.evaluate_rows(
"eval_candidate": { "eval_candidate": {
"type": "agent", "type": "agent",
"config": agent_config, "config": agent_config,
} },
} },
) )
``` ```
@ -237,7 +238,9 @@ GENERATED_RESPONSE: {generated_answer}
EXPECTED_RESPONSE: {expected_answer} EXPECTED_RESPONSE: {expected_answer}
""" """
input_query = "What are the top 5 topics that were explained? Only list succinct bullet points." input_query = (
"What are the top 5 topics that were explained? Only list succinct bullet points."
)
generated_answer = """ generated_answer = """
Here are the top 5 topics that were explained in the documentation for Torchtune: Here are the top 5 topics that were explained in the documentation for Torchtune:
@ -268,7 +271,9 @@ scoring_params = {
"braintrust::factuality": None, "braintrust::factuality": None,
} }
response = client.scoring.score(input_rows=dataset_rows, scoring_functions=scoring_params) response = client.scoring.score(
input_rows=dataset_rows, scoring_functions=scoring_params
)
``` ```
## Running Evaluations via CLI ## Running Evaluations via CLI


@@ -33,7 +33,11 @@ from llama_stack_client.types import (
 Types:
 
 ```python
-from llama_stack_client.types import ListToolGroupsResponse, ToolGroup, ToolgroupListResponse
+from llama_stack_client.types import (
+    ListToolGroupsResponse,
+    ToolGroup,
+    ToolgroupListResponse,
+)
 ```
 
 Methods:
@@ -444,7 +448,11 @@ Methods:
 Types:
 
 ```python
-from llama_stack_client.types import EvalTask, ListEvalTasksResponse, EvalTaskListResponse
+from llama_stack_client.types import (
+    EvalTask,
+    ListEvalTasksResponse,
+    EvalTaskListResponse,
+)
 ```
 
 Methods:


@@ -224,7 +224,7 @@ client = LlamaStackClient(base_url="http://localhost:5001")
 response = client.inference.chat_completion(
     messages=[
         {"role": "system", "content": "You are a friendly assistant."},
-        {"role": "user", "content": "Write a two-sentence poem about llama."}
+        {"role": "user", "content": "Write a two-sentence poem about llama."},
     ],
     model_id=INFERENCE_MODEL,
 )


@ -86,9 +86,7 @@ class ShieldCallStep(StepCommon):
@json_schema_type @json_schema_type
class MemoryRetrievalStep(StepCommon): class MemoryRetrievalStep(StepCommon):
step_type: Literal[StepType.memory_retrieval.value] = ( step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value
StepType.memory_retrieval.value
)
vector_db_ids: str vector_db_ids: str
inserted_context: InterleavedContent inserted_context: InterleavedContent
@ -184,9 +182,7 @@ class AgentTurnResponseEventType(Enum):
@json_schema_type @json_schema_type
class AgentTurnResponseStepStartPayload(BaseModel): class AgentTurnResponseStepStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_start.value] = ( event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value
AgentTurnResponseEventType.step_start.value
)
step_type: StepType step_type: StepType
step_id: str step_id: str
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict) metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
@ -194,9 +190,7 @@ class AgentTurnResponseStepStartPayload(BaseModel):
@json_schema_type @json_schema_type
class AgentTurnResponseStepCompletePayload(BaseModel): class AgentTurnResponseStepCompletePayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_complete.value] = ( event_type: Literal[AgentTurnResponseEventType.step_complete.value] = AgentTurnResponseEventType.step_complete.value
AgentTurnResponseEventType.step_complete.value
)
step_type: StepType step_type: StepType
step_id: str step_id: str
step_details: Step step_details: Step
@ -206,9 +200,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
class AgentTurnResponseStepProgressPayload(BaseModel): class AgentTurnResponseStepProgressPayload(BaseModel):
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
event_type: Literal[AgentTurnResponseEventType.step_progress.value] = ( event_type: Literal[AgentTurnResponseEventType.step_progress.value] = AgentTurnResponseEventType.step_progress.value
AgentTurnResponseEventType.step_progress.value
)
step_type: StepType step_type: StepType
step_id: str step_id: str
@ -217,17 +209,13 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
@json_schema_type @json_schema_type
class AgentTurnResponseTurnStartPayload(BaseModel): class AgentTurnResponseTurnStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_start.value] = ( event_type: Literal[AgentTurnResponseEventType.turn_start.value] = AgentTurnResponseEventType.turn_start.value
AgentTurnResponseEventType.turn_start.value
)
turn_id: str turn_id: str
@json_schema_type @json_schema_type
class AgentTurnResponseTurnCompletePayload(BaseModel): class AgentTurnResponseTurnCompletePayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = ( event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = AgentTurnResponseEventType.turn_complete.value
AgentTurnResponseEventType.turn_complete.value
)
turn: Turn turn: Turn
@ -329,9 +317,7 @@ class Agents(Protocol):
toolgroups: Optional[List[AgentToolGroup]] = None, toolgroups: Optional[List[AgentToolGroup]] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod( @webmethod(route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET")
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET"
)
async def get_agents_turn( async def get_agents_turn(
self, self,
agent_id: str, agent_id: str,


@ -63,9 +63,7 @@ class EventLogger:
if isinstance(chunk, ToolResponseMessage): if isinstance(chunk, ToolResponseMessage):
yield ( yield (
chunk, chunk,
LogEvent( LogEvent(role="CustomTool", content=chunk.content, color="grey"),
role="CustomTool", content=chunk.content, color="grey"
),
) )
continue continue
@ -81,17 +79,12 @@ class EventLogger:
step_type = event.payload.step_type step_type = event.payload.step_type
# handle safety # handle safety
if ( if step_type == StepType.shield_call and event_type == EventType.step_complete.value:
step_type == StepType.shield_call
and event_type == EventType.step_complete.value
):
violation = event.payload.step_details.violation violation = event.payload.step_details.violation
if not violation: if not violation:
yield ( yield (
event, event,
LogEvent( LogEvent(role=step_type, content="No Violation", color="magenta"),
role=step_type, content="No Violation", color="magenta"
),
) )
else: else:
yield ( yield (
@ -110,9 +103,7 @@ class EventLogger:
# TODO: Currently this event is never received # TODO: Currently this event is never received
yield ( yield (
event, event,
LogEvent( LogEvent(role=step_type, content="", end="", color="yellow"),
role=step_type, content="", end="", color="yellow"
),
) )
elif event_type == EventType.step_progress.value: elif event_type == EventType.step_progress.value:
# HACK: if previous was not step/event was not inference's step_progress # HACK: if previous was not step/event was not inference's step_progress
@ -125,9 +116,7 @@ class EventLogger:
): ):
yield ( yield (
event, event,
LogEvent( LogEvent(role=step_type, content="", end="", color="yellow"),
role=step_type, content="", end="", color="yellow"
),
) )
delta = event.payload.delta delta = event.payload.delta
@ -161,9 +150,7 @@ class EventLogger:
if event_type == EventType.step_complete.value: if event_type == EventType.step_complete.value:
response = event.payload.step_details.model_response response = event.payload.step_details.model_response
if response.tool_calls: if response.tool_calls:
content = ToolUtils.encode_tool_call( content = ToolUtils.encode_tool_call(response.tool_calls[0], tool_prompt_format)
response.tool_calls[0], tool_prompt_format
)
else: else:
content = response.content content = response.content
yield ( yield (
@ -202,10 +189,7 @@ class EventLogger:
), ),
) )
if ( if step_type == StepType.memory_retrieval and event_type == EventType.step_complete.value:
step_type == StepType.memory_retrieval
and event_type == EventType.step_complete.value
):
details = event.payload.step_details details = event.payload.step_details
inserted_context = interleaved_content_as_str(details.inserted_context) inserted_context = interleaved_content_as_str(details.inserted_context)
content = f"fetched {len(inserted_context)} bytes from {details.vector_db_ids}" content = f"fetched {len(inserted_context)} bytes from {details.vector_db_ids}"


@ -39,6 +39,4 @@ class DatasetIO(Protocol):
) -> PaginatedRowsResult: ... ) -> PaginatedRowsResult: ...
@webmethod(route="/datasetio/rows", method="POST") @webmethod(route="/datasetio/rows", method="POST")
async def append_rows( async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
self, dataset_id: str, rows: List[Dict[str, Any]]
) -> None: ...


@ -63,9 +63,7 @@ class AppEvalTaskConfig(BaseModel):
EvalTaskConfig = register_schema( EvalTaskConfig = register_schema(
Annotated[ Annotated[Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type")],
Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type")
],
name="EvalTaskConfig", name="EvalTaskConfig",
) )


@ -245,9 +245,7 @@ class JsonSchemaResponseFormat(BaseModel):
:param json_schema: The JSON schema the response should conform to. In a Python SDK, this is often a `pydantic` model. :param json_schema: The JSON schema the response should conform to. In a Python SDK, this is often a `pydantic` model.
""" """
type: Literal[ResponseFormatType.json_schema.value] = ( type: Literal[ResponseFormatType.json_schema.value] = ResponseFormatType.json_schema.value
ResponseFormatType.json_schema.value
)
json_schema: Dict[str, Any] json_schema: Dict[str, Any]
@ -406,9 +404,7 @@ class Inference(Protocol):
response_format: Optional[ResponseFormat] = None, response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> Union[ ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
"""Generate a chat completion for the given messages using the specified model. """Generate a chat completion for the given messages using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint. :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.


@ -89,9 +89,7 @@ class QATFinetuningConfig(BaseModel):
AlgorithmConfig = register_schema( AlgorithmConfig = register_schema(
Annotated[ Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")],
Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
],
name="AlgorithmConfig", name="AlgorithmConfig",
) )
@ -204,14 +202,10 @@ class PostTraining(Protocol):
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ... async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
@webmethod(route="/post-training/job/status", method="GET") @webmethod(route="/post-training/job/status", method="GET")
async def get_training_job_status( async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: ...
self, job_uuid: str
) -> Optional[PostTrainingJobStatusResponse]: ...
@webmethod(route="/post-training/job/cancel", method="POST") @webmethod(route="/post-training/job/cancel", method="POST")
async def cancel_training_job(self, job_uuid: str) -> None: ... async def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post-training/job/artifacts", method="GET") @webmethod(route="/post-training/job/artifacts", method="GET")
async def get_training_job_artifacts( async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]: ...
self, job_uuid: str
) -> Optional[PostTrainingJobArtifactsResponse]: ...


@ -23,9 +23,7 @@ class ResourceType(Enum):
class Resource(BaseModel): class Resource(BaseModel):
"""Base class for all Llama Stack resources""" """Base class for all Llama Stack resources"""
identifier: str = Field( identifier: str = Field(description="Unique identifier for this resource in llama stack")
description="Unique identifier for this resource in llama stack"
)
provider_resource_id: str = Field( provider_resource_id: str = Field(
description="Unique identifier for this resource in the provider", description="Unique identifier for this resource in the provider",
@ -34,6 +32,4 @@ class Resource(BaseModel):
provider_id: str = Field(description="ID of the provider that owns this resource") provider_id: str = Field(description="ID of the provider that owns this resource")
type: ResourceType = Field( type: ResourceType = Field(description="Type of resource (e.g. 'model', 'shield', 'vector_db', etc.)")
description="Type of resource (e.g. 'model', 'shield', 'vector_db', etc.)"
)


@ -43,9 +43,7 @@ class AggregationFunctionType(Enum):
@json_schema_type @json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel): class LLMAsJudgeScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.llm_as_judge.value] = ( type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
ScoringFnParamsType.llm_as_judge.value
)
judge_model: str judge_model: str
prompt_template: Optional[str] = None prompt_template: Optional[str] = None
judge_score_regexes: Optional[List[str]] = Field( judge_score_regexes: Optional[List[str]] = Field(
@ -60,9 +58,7 @@ class LLMAsJudgeScoringFnParams(BaseModel):
@json_schema_type @json_schema_type
class RegexParserScoringFnParams(BaseModel): class RegexParserScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.regex_parser.value] = ( type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
ScoringFnParamsType.regex_parser.value
)
parsing_regexes: Optional[List[str]] = Field( parsing_regexes: Optional[List[str]] = Field(
description="Regex to extract the answer from generated response", description="Regex to extract the answer from generated response",
default_factory=list, default_factory=list,
@ -112,9 +108,7 @@ class CommonScoringFnFields(BaseModel):
@json_schema_type @json_schema_type
class ScoringFn(CommonScoringFnFields, Resource): class ScoringFn(CommonScoringFnFields, Resource):
type: Literal[ResourceType.scoring_function.value] = ( type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value
ResourceType.scoring_function.value
)
@property @property
def scoring_fn_id(self) -> str: def scoring_fn_id(self) -> str:
@ -141,9 +135,7 @@ class ScoringFunctions(Protocol):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ... async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
@webmethod(route="/scoring-functions/{scoring_fn_id}", method="GET") @webmethod(route="/scoring-functions/{scoring_fn_id}", method="GET")
async def get_scoring_function( async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ...
self, scoring_fn_id: str, /
) -> Optional[ScoringFn]: ...
@webmethod(route="/scoring-functions", method="POST") @webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function( async def register_scoring_function(


@ -102,9 +102,7 @@ class StructuredLogType(Enum):
@json_schema_type @json_schema_type
class SpanStartPayload(BaseModel): class SpanStartPayload(BaseModel):
type: Literal[StructuredLogType.SPAN_START.value] = ( type: Literal[StructuredLogType.SPAN_START.value] = StructuredLogType.SPAN_START.value
StructuredLogType.SPAN_START.value
)
name: str name: str
parent_span_id: Optional[str] = None parent_span_id: Optional[str] = None
@ -190,9 +188,7 @@ class QuerySpanTreeResponse(BaseModel):
@runtime_checkable @runtime_checkable
class Telemetry(Protocol): class Telemetry(Protocol):
@webmethod(route="/telemetry/events", method="POST") @webmethod(route="/telemetry/events", method="POST")
async def log_event( async def log_event(self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400) -> None: ...
self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400
) -> None: ...
@webmethod(route="/telemetry/traces", method="GET") @webmethod(route="/telemetry/traces", method="GET")
async def query_traces( async def query_traces(


@ -64,9 +64,7 @@ RAGQueryGeneratorConfig = register_schema(
class RAGQueryConfig(BaseModel): class RAGQueryConfig(BaseModel):
# This config defines how a query is generated using the messages # This config defines how a query is generated using the messages
# for memory bank retrieval. # for memory bank retrieval.
query_generator_config: RAGQueryGeneratorConfig = Field( query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
default=DefaultRAGQueryGeneratorConfig()
)
max_tokens_in_context: int = 4096 max_tokens_in_context: int = 4096
max_chunks: int = 5 max_chunks: int = 5


@ -150,8 +150,6 @@ class ToolRuntime(Protocol):
) -> List[ToolDef]: ... ) -> List[ToolDef]: ...
@webmethod(route="/tool-runtime/invoke", method="POST") @webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool( async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
"""Run a tool with the given arguments""" """Run a tool with the given arguments"""
... ...


@ -147,9 +147,7 @@ class ParallelDownloader:
"follow_redirects": True, "follow_redirects": True,
} }
async def retry_with_exponential_backoff( async def retry_with_exponential_backoff(self, task: DownloadTask, func, *args, **kwargs):
self, task: DownloadTask, func, *args, **kwargs
):
last_exception = None last_exception = None
for attempt in range(task.max_retries): for attempt in range(task.max_retries):
try: try:
@ -166,13 +164,9 @@ class ParallelDownloader:
continue continue
raise last_exception raise last_exception
async def get_file_info( async def get_file_info(self, client: httpx.AsyncClient, task: DownloadTask) -> None:
self, client: httpx.AsyncClient, task: DownloadTask
) -> None:
async def _get_info(): async def _get_info():
response = await client.head( response = await client.head(task.url, headers={"Accept-Encoding": "identity"}, **self.client_options)
task.url, headers={"Accept-Encoding": "identity"}, **self.client_options
)
response.raise_for_status() response.raise_for_status()
return response return response
@ -201,14 +195,10 @@ class ParallelDownloader:
return False return False
return os.path.getsize(task.output_file) == task.total_size return os.path.getsize(task.output_file) == task.total_size
async def download_chunk( async def download_chunk(self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int) -> None:
self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int
) -> None:
async def _download_chunk(): async def _download_chunk():
headers = {"Range": f"bytes={start}-{end}"} headers = {"Range": f"bytes={start}-{end}"}
async with client.stream( async with client.stream("GET", task.url, headers=headers, **self.client_options) as response:
"GET", task.url, headers=headers, **self.client_options
) as response:
response.raise_for_status() response.raise_for_status()
with open(task.output_file, "ab") as file: with open(task.output_file, "ab") as file:
@ -225,8 +215,7 @@ class ParallelDownloader:
await self.retry_with_exponential_backoff(task, _download_chunk) await self.retry_with_exponential_backoff(task, _download_chunk)
except Exception as e: except Exception as e:
raise DownloadError( raise DownloadError(
f"Failed to download chunk {start}-{end} after " f"Failed to download chunk {start}-{end} after {task.max_retries} attempts: {str(e)}"
f"{task.max_retries} attempts: {str(e)}"
) from e ) from e
async def prepare_download(self, task: DownloadTask) -> None: async def prepare_download(self, task: DownloadTask) -> None:
@ -244,9 +233,7 @@ class ParallelDownloader:
# Check if file is already downloaded # Check if file is already downloaded
if os.path.exists(task.output_file): if os.path.exists(task.output_file):
if self.verify_file_integrity(task): if self.verify_file_integrity(task):
self.console.print( self.console.print(f"[green]Already downloaded {task.output_file}[/green]")
f"[green]Already downloaded {task.output_file}[/green]"
)
self.progress.update(task.task_id, completed=task.total_size) self.progress.update(task.task_id, completed=task.total_size)
return return
@ -259,9 +246,7 @@ class ParallelDownloader:
current_pos = task.downloaded_size current_pos = task.downloaded_size
while current_pos < task.total_size: while current_pos < task.total_size:
chunk_end = min( chunk_end = min(current_pos + chunk_size - 1, task.total_size - 1)
current_pos + chunk_size - 1, task.total_size - 1
)
chunks.append((current_pos, chunk_end)) chunks.append((current_pos, chunk_end))
current_pos = chunk_end + 1 current_pos = chunk_end + 1
@ -273,18 +258,12 @@ class ParallelDownloader:
raise DownloadError(f"Download failed: {str(e)}") from e raise DownloadError(f"Download failed: {str(e)}") from e
except Exception as e: except Exception as e:
self.progress.update( self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
task.task_id, description=f"[red]Failed: {task.output_file}[/red]" raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e
)
raise DownloadError(
f"Download failed for {task.output_file}: {str(e)}"
) from e
def has_disk_space(self, tasks: List[DownloadTask]) -> bool: def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
try: try:
total_remaining_size = sum( total_remaining_size = sum(task.total_size - task.downloaded_size for task in tasks)
task.total_size - task.downloaded_size for task in tasks
)
dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file)) dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
free_space = shutil.disk_usage(dir_path).free free_space = shutil.disk_usage(dir_path).free
@ -314,9 +293,7 @@ class ParallelDownloader:
with self.progress: with self.progress:
for task in tasks: for task in tasks:
desc = f"Downloading {Path(task.output_file).name}" desc = f"Downloading {Path(task.output_file).name}"
task.task_id = self.progress.add_task( task.task_id = self.progress.add_task(desc, total=task.total_size, completed=task.downloaded_size)
desc, total=task.total_size, completed=task.downloaded_size
)
semaphore = asyncio.Semaphore(self.max_concurrent_downloads) semaphore = asyncio.Semaphore(self.max_concurrent_downloads)
@ -332,9 +309,7 @@ class ParallelDownloader:
if failed_tasks: if failed_tasks:
self.console.print("\n[red]Some downloads failed:[/red]") self.console.print("\n[red]Some downloads failed:[/red]")
for task, error in failed_tasks: for task, error in failed_tasks:
self.console.print( self.console.print(f"[red]- {Path(task.output_file).name}: {error}[/red]")
f"[red]- {Path(task.output_file).name}: {error}[/red]"
)
raise DownloadError(f"{len(failed_tasks)} downloads failed") raise DownloadError(f"{len(failed_tasks)} downloads failed")
@ -396,11 +371,7 @@ def _meta_download(
output_file = str(output_dir / f) output_file = str(output_dir / f)
url = meta_url.replace("*", f"{info.folder}/{f}") url = meta_url.replace("*", f"{info.folder}/{f}")
total_size = info.pth_size if "consolidated" in f else 0 total_size = info.pth_size if "consolidated" in f else 0
tasks.append( tasks.append(DownloadTask(url=url, output_file=output_file, total_size=total_size, max_retries=3))
DownloadTask(
url=url, output_file=output_file, total_size=total_size, max_retries=3
)
)
# Initialize and run parallel downloader # Initialize and run parallel downloader
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads) downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
@ -446,14 +417,10 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
if any(output_dir.iterdir()): if any(output_dir.iterdir()):
console.print( console.print(f"[yellow]Output directory {output_dir} is not empty.[/yellow]")
f"[yellow]Output directory {output_dir} is not empty.[/yellow]"
)
while True: while True:
resp = input( resp = input("Do you want to (C)ontinue download or (R)estart completely? (continue/restart): ")
"Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
)
if resp.lower() in ["restart", "r"]: if resp.lower() in ["restart", "r"]:
shutil.rmtree(output_dir) shutil.rmtree(output_dir)
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
@ -471,9 +438,7 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
] ]
# Initialize and run parallel downloader # Initialize and run parallel downloader
downloader = ParallelDownloader( downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
max_concurrent_downloads=max_concurrent_downloads
)
asyncio.run(downloader.download_all(tasks)) asyncio.run(downloader.download_all(tasks))


@ -47,33 +47,20 @@ class ModelPromptFormat(Subcommand):
# Only Llama 3.1 and 3.2 are supported # Only Llama 3.1 and 3.2 are supported
supported_model_ids = [ supported_model_ids = [
m m for m in CoreModelId if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
for m in CoreModelId
if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
] ]
model_str = "\n".join([m.value for m in supported_model_ids]) model_str = "\n".join([m.value for m in supported_model_ids])
try: try:
model_id = CoreModelId(args.model_name) model_id = CoreModelId(args.model_name)
except ValueError: except ValueError:
self.parser.error( self.parser.error(f"{args.model_name} is not a valid Model. Choose one from --\n{model_str}")
f"{args.model_name} is not a valid Model. Choose one from --\n{model_str}"
)
if model_id not in supported_model_ids: if model_id not in supported_model_ids:
self.parser.error( self.parser.error(f"{model_id} is not a valid Model. Choose one from --\n {model_str}")
f"{model_id} is not a valid Model. Choose one from --\n {model_str}"
)
llama_3_1_file = ( llama_3_1_file = importlib.resources.files("llama_models") / "llama3_1/prompt_format.md"
importlib.resources.files("llama_models") / "llama3_1/prompt_format.md" llama_3_2_text_file = importlib.resources.files("llama_models") / "llama3_2/text_prompt_format.md"
) llama_3_2_vision_file = importlib.resources.files("llama_models") / "llama3_2/vision_prompt_format.md"
llama_3_2_text_file = (
importlib.resources.files("llama_models") / "llama3_2/text_prompt_format.md"
)
llama_3_2_vision_file = (
importlib.resources.files("llama_models")
/ "llama3_2/vision_prompt_format.md"
)
if model_family(model_id) == ModelFamily.llama3_1: if model_family(model_id) == ModelFamily.llama3_1:
with importlib.resources.as_file(llama_3_1_file) as f: with importlib.resources.as_file(llama_3_1_file) as f:
content = f.open("r").read() content = f.open("r").read()


@ -17,16 +17,12 @@ class PromptGuardModel(BaseModel):
"""Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed.""" """Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""
model_id: str = "Prompt-Guard-86M" model_id: str = "Prompt-Guard-86M"
description: str = ( description: str = "Prompt Guard. NOTE: this model will not be provided via `llama` CLI soon."
"Prompt Guard. NOTE: this model will not be provided via `llama` CLI soon."
)
is_featured: bool = False is_featured: bool = False
huggingface_repo: str = "meta-llama/Prompt-Guard-86M" huggingface_repo: str = "meta-llama/Prompt-Guard-86M"
max_seq_length: int = 2048 max_seq_length: int = 2048
is_instruct_model: bool = False is_instruct_model: bool = False
quantization_format: CheckpointQuantizationFormat = ( quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
CheckpointQuantizationFormat.bf16
)
arch_args: Dict[str, Any] = Field(default_factory=dict) arch_args: Dict[str, Any] = Field(default_factory=dict)
recommended_sampling_params: Optional[SamplingParams] = None recommended_sampling_params: Optional[SamplingParams] = None


@ -56,9 +56,7 @@ def available_templates_specs() -> Dict[str, BuildConfig]:
return template_specs return template_specs
def run_stack_build_command( def run_stack_build_command(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
parser: argparse.ArgumentParser, args: argparse.Namespace
) -> None:
if args.list_templates: if args.list_templates:
return _run_template_list_cmd() return _run_template_list_cmd()
@ -129,11 +127,7 @@ def run_stack_build_command(
providers = dict() providers = dict()
for api, providers_for_api in get_provider_registry().items(): for api, providers_for_api in get_provider_registry().items():
available_providers = [ available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")]
x
for x in providers_for_api.keys()
if x not in ("remote", "remote::sample")
]
api_provider = prompt( api_provider = prompt(
"> Enter provider for API {}: ".format(api.value), "> Enter provider for API {}: ".format(api.value),
completer=WordCompleter(available_providers), completer=WordCompleter(available_providers),
@ -156,9 +150,7 @@ def run_stack_build_command(
description=description, description=description,
) )
build_config = BuildConfig( build_config = BuildConfig(image_type=image_type, distribution_spec=distribution_spec)
image_type=image_type, distribution_spec=distribution_spec
)
else: else:
with open(args.config, "r") as f: with open(args.config, "r") as f:
try: try:
@ -179,9 +171,7 @@ def run_stack_build_command(
if args.print_deps_only: if args.print_deps_only:
print(f"# Dependencies for {args.template or args.config or image_name}") print(f"# Dependencies for {args.template or args.config or image_name}")
normal_deps, special_deps = get_provider_dependencies( normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
build_config.distribution_spec.providers
)
normal_deps += SERVER_DEPENDENCIES normal_deps += SERVER_DEPENDENCIES
print(f"uv pip install {' '.join(normal_deps)}") print(f"uv pip install {' '.join(normal_deps)}")
for special_dep in special_deps: for special_dep in special_deps:
@ -206,9 +196,7 @@ def _generate_run_config(
""" """
apis = list(build_config.distribution_spec.providers.keys()) apis = list(build_config.distribution_spec.providers.keys())
run_config = StackRunConfig( run_config = StackRunConfig(
container_image=( container_image=(image_name if build_config.image_type == ImageType.container.value else None),
image_name if build_config.image_type == ImageType.container.value else None
),
image_name=image_name, image_name=image_name,
apis=apis, apis=apis,
providers={}, providers={},
@ -228,13 +216,9 @@ def _generate_run_config(
if p.deprecation_error: if p.deprecation_error:
raise InvalidProviderError(p.deprecation_error) raise InvalidProviderError(p.deprecation_error)
config_type = instantiate_class_type( config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
provider_registry[Api(api)][provider_type].config_class
)
if hasattr(config_type, "sample_run_config"): if hasattr(config_type, "sample_run_config"):
config = config_type.sample_run_config( config = config_type.sample_run_config(__distro_dir__=f"distributions/{image_name}")
__distro_dir__=f"distributions/{image_name}"
)
else: else:
config = {} config = {}
@ -269,9 +253,7 @@ def _run_stack_build_command_from_build_config(
image_name = f"distribution-{template_name}" image_name = f"distribution-{template_name}"
else: else:
if not image_name: if not image_name:
raise ValueError( raise ValueError("Please specify an image name when building a container image without a template")
"Please specify an image name when building a container image without a template"
)
elif build_config.image_type == ImageType.conda.value: elif build_config.image_type == ImageType.conda.value:
if not image_name: if not image_name:
raise ValueError("Please specify an image name when building a conda image") raise ValueError("Please specify an image name when building a conda image")
@ -299,10 +281,7 @@ def _run_stack_build_command_from_build_config(
if template_name: if template_name:
# copy run.yaml from template to build_dir instead of generating it again # copy run.yaml from template to build_dir instead of generating it again
template_path = ( template_path = importlib.resources.files("llama_stack") / f"templates/{template_name}/run.yaml"
importlib.resources.files("llama_stack")
/ f"templates/{template_name}/run.yaml"
)
with importlib.resources.as_file(template_path) as path: with importlib.resources.as_file(template_path) as path:
run_config_file = build_dir / f"{template_name}-run.yaml" run_config_file = build_dir / f"{template_name}-run.yaml"
shutil.copy(path, run_config_file) shutil.copy(path, run_config_file)


@@ -82,31 +82,21 @@ class StackRun(Subcommand):
        if not config_file.exists() and not has_yaml_suffix:
            # check if this is a template
-            config_file = (
-                Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
-            )
+            config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
            if config_file.exists():
                template_name = args.config

        if not config_file.exists() and not has_yaml_suffix:
            # check if it's a build config saved to conda dir
-            config_file = Path(
-                BUILDS_BASE_DIR / ImageType.conda.value / f"{args.config}-run.yaml"
-            )
+            config_file = Path(BUILDS_BASE_DIR / ImageType.conda.value / f"{args.config}-run.yaml")

        if not config_file.exists() and not has_yaml_suffix:
            # check if it's a build config saved to container dir
-            config_file = Path(
-                BUILDS_BASE_DIR / ImageType.container.value / f"{args.config}-run.yaml"
-            )
+            config_file = Path(BUILDS_BASE_DIR / ImageType.container.value / f"{args.config}-run.yaml")

        if not config_file.exists() and not has_yaml_suffix:
            # check if it's a build config saved to ~/.llama dir
-            config_file = Path(
-                DISTRIBS_BASE_DIR
-                / f"llamastack-{args.config}"
-                / f"{args.config}-run.yaml"
-            )
+            config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")

        if not config_file.exists():
            self.parser.error(
@@ -119,15 +109,8 @@ class StackRun(Subcommand):
        config = parse_and_maybe_upgrade_config(config_dict)

        if config.container_image:
-            script = (
-                importlib.resources.files("llama_stack")
-                / "distribution/start_container.sh"
-            )
-            image_name = (
-                f"distribution-{template_name}"
-                if template_name
-                else config.container_image
-            )
+            script = importlib.resources.files("llama_stack") / "distribution/start_container.sh"
+            image_name = f"distribution-{template_name}" if template_name else config.container_image
            run_args = [script, image_name]
        else:
            current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
@@ -145,11 +128,7 @@ class StackRun(Subcommand):
            if env_name == "base":
                return os.environ.get("CONDA_PREFIX")
            # Get conda environments info
-            conda_env_info = json.loads(
-                subprocess.check_output(
-                    ["conda", "info", "--envs", "--json"]
-                ).decode()
-            )
+            conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
            envs = conda_env_info["envs"]
            for envpath in envs:
                if envpath.endswith(env_name):
@@ -173,10 +152,7 @@ class StackRun(Subcommand):
            )
            return

-        script = (
-            importlib.resources.files("llama_stack")
-            / "distribution/start_conda_env.sh"
-        )
+        script = importlib.resources.files("llama_stack") / "distribution/start_conda_env.sh"
        run_args = [
            script,
            image_name,


@@ -22,11 +22,7 @@ def format_row(row, col_widths):
            if line.strip() == "":
                lines.append("")
            else:
-                lines.extend(
-                    textwrap.wrap(
-                        line, width, break_long_words=False, replace_whitespace=False
-                    )
-                )
+                lines.extend(textwrap.wrap(line, width, break_long_words=False, replace_whitespace=False))
        return lines

    wrapped = [wrap(item, width) for item, width in zip(row, col_widths)]


@@ -41,9 +41,7 @@ def up_to_date_config():
          - provider_id: provider1
            provider_type: inline::meta-reference
            config: {{}}
-    """.format(
-        version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat()
-    )
+    """.format(version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat())
    )
@@ -83,9 +81,7 @@ def old_config():
          telemetry:
            provider_type: noop
            config: {{}}
-    """.format(
-        built_at=datetime.now().isoformat()
-    )
+    """.format(built_at=datetime.now().isoformat())
    )
@@ -108,10 +104,7 @@ def test_parse_and_maybe_upgrade_config_up_to_date(up_to_date_config):
def test_parse_and_maybe_upgrade_config_old_format(old_config):
    result = parse_and_maybe_upgrade_config(old_config)
    assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
-    assert all(
-        api in result.providers
-        for api in ["inference", "safety", "memory", "telemetry"]
-    )
+    assert all(api in result.providers for api in ["inference", "safety", "memory", "telemetry"])
    safety_provider = result.providers["safety"][0]
    assert safety_provider.provider_type == "meta-reference"
    assert "llama_guard_shield" in safety_provider.config


@@ -72,9 +72,7 @@ def load_checksums(checklist_path: Path) -> Dict[str, str]:
    return checksums


-def verify_files(
-    model_dir: Path, checksums: Dict[str, str], console: Console
-) -> List[VerificationResult]:
+def verify_files(model_dir: Path, checksums: Dict[str, str], console: Console) -> List[VerificationResult]:
    results = []

    with Progress(


@@ -58,22 +58,14 @@ def get_provider_dependencies(
    for api_str, provider_or_providers in config_providers.items():
        providers_for_api = all_providers[Api(api_str)]

-        providers = (
-            provider_or_providers
-            if isinstance(provider_or_providers, list)
-            else [provider_or_providers]
-        )
+        providers = provider_or_providers if isinstance(provider_or_providers, list) else [provider_or_providers]

        for provider in providers:
            # Providers from BuildConfig and RunConfig are subtly different - not great
-            provider_type = (
-                provider if isinstance(provider, str) else provider.provider_type
-            )
+            provider_type = provider if isinstance(provider, str) else provider.provider_type
            if provider_type not in providers_for_api:
-                raise ValueError(
-                    f"Provider `{provider}` is not available for API `{api_str}`"
-                )
+                raise ValueError(f"Provider `{provider}` is not available for API `{api_str}`")

            provider_spec = providers_for_api[provider_type]
            deps.extend(provider_spec.pip_packages)
@@ -109,19 +101,13 @@ def build_image(
    image_name: str,
    template_or_config: str,
):
-    container_base = (
-        build_config.distribution_spec.container_image or "python:3.10-slim"
-    )
+    container_base = build_config.distribution_spec.container_image or "python:3.10-slim"

-    normal_deps, special_deps = get_provider_dependencies(
-        build_config.distribution_spec.providers
-    )
+    normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
    normal_deps += SERVER_DEPENDENCIES

    if build_config.image_type == ImageType.container.value:
-        script = str(
-            importlib.resources.files("llama_stack") / "distribution/build_container.sh"
-        )
+        script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")
        args = [
            script,
            template_or_config,
@@ -132,9 +118,7 @@ def build_image(
            " ".join(normal_deps),
        ]
    elif build_config.image_type == ImageType.conda.value:
-        script = str(
-            importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh"
-        )
+        script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
        args = [
            script,
            str(image_name),
@@ -142,9 +126,7 @@ def build_image(
            " ".join(normal_deps),
        ]
    elif build_config.image_type == ImageType.venv.value:
-        script = str(
-            importlib.resources.files("llama_stack") / "distribution/build_venv.sh"
-        )
+        script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh")
        args = [
            script,
            str(image_name),


@@ -68,9 +68,7 @@ def create_api_client_class(protocol) -> Type:
                return_type = None
            else:
                return_type = extract_non_async_iterator_type(sig.return_annotation)
-            assert return_type, (
-                f"Could not extract return type for {sig.return_annotation}"
-            )
+            assert return_type, f"Could not extract return type for {sig.return_annotation}"

            async with httpx.AsyncClient() as client:
                params = self.httpx_request_params(method_name, *args, **kwargs)
@@ -87,9 +85,7 @@ def create_api_client_class(protocol) -> Type:
            webmethod, sig = self.routes[method_name]

            return_type = extract_async_iterator_type(sig.return_annotation)
-            assert return_type, (
-                f"Could not extract return type for {sig.return_annotation}"
-            )
+            assert return_type, f"Could not extract return type for {sig.return_annotation}"

            async with httpx.AsyncClient() as client:
                params = self.httpx_request_params(method_name, *args, **kwargs)
@@ -204,9 +200,7 @@ async def example(model: str = None):
    if not model:
        model = "Llama3.2-3B-Instruct"

-    message = UserMessage(
-        content="hello world, write me a 2 sentence poem about the moon"
-    )
+    message = UserMessage(content="hello world, write me a 2 sentence poem about the moon")
    cprint(f"User>{message.content}", "green")

    stream = True


@ -26,9 +26,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def configure_single_provider( def configure_single_provider(registry: Dict[str, ProviderSpec], provider: Provider) -> Provider:
registry: Dict[str, ProviderSpec], provider: Provider
) -> Provider:
provider_spec = registry[provider.provider_type] provider_spec = registry[provider.provider_type]
config_type = instantiate_class_type(provider_spec.config_class) config_type = instantiate_class_type(provider_spec.config_class)
try: try:
@ -47,9 +45,7 @@ def configure_single_provider(
) )
def configure_api_providers( def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec) -> StackRunConfig:
config: StackRunConfig, build_spec: DistributionSpec
) -> StackRunConfig:
is_nux = len(config.providers) == 0 is_nux = len(config.providers) == 0
if is_nux: if is_nux:
@ -87,9 +83,7 @@ def configure_api_providers(
updated_providers = [] updated_providers = []
for p in existing_providers: for p in existing_providers:
logger.info(f"> Configuring provider `({p.provider_type})`") logger.info(f"> Configuring provider `({p.provider_type})`")
updated_providers.append( updated_providers.append(configure_single_provider(provider_registry[api], p))
configure_single_provider(provider_registry[api], p)
)
logger.info("") logger.info("")
else: else:
# we are newly configuring this API # we are newly configuring this API
@ -114,11 +108,7 @@ def configure_api_providers(
configure_single_provider( configure_single_provider(
provider_registry[api], provider_registry[api],
Provider( Provider(
provider_id=( provider_id=(f"{provider_type}-{i:02d}" if len(plist) > 1 else provider_type),
f"{provider_type}-{i:02d}"
if len(plist) > 1
else provider_type
),
provider_type=provider_type, provider_type=provider_type,
config={}, config={},
), ),
@ -137,11 +127,7 @@ def upgrade_from_routing_table(
def get_providers(entries): def get_providers(entries):
return [ return [
Provider( Provider(
provider_id=( provider_id=(f"{entry['provider_type']}-{i:02d}" if len(entries) > 1 else entry["provider_type"]),
f"{entry['provider_type']}-{i:02d}"
if len(entries) > 1
else entry["provider_type"]
),
provider_type=entry["provider_type"], provider_type=entry["provider_type"],
config=entry["config"], config=entry["config"],
) )


@@ -163,9 +163,7 @@ a default SQLite store will be used.""",
class BuildConfig(BaseModel):
    version: str = LLAMA_STACK_BUILD_CONFIG_VERSION

-    distribution_spec: DistributionSpec = Field(
-        description="The distribution spec to build including API providers. "
-    )
+    distribution_spec: DistributionSpec = Field(description="The distribution spec to build including API providers. ")
    image_type: str = Field(
        default="conda",
        description="Type of package to build (conda | container | venv)",


@@ -55,9 +55,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
def providable_apis() -> List[Api]:
-    routing_table_apis = set(
-        x.routing_table_api for x in builtin_automatically_routed_apis()
-    )
+    routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
    return [api for api in Api if api not in routing_table_apis and api != Api.inspect]


@ -154,9 +154,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
def sync_generator(): def sync_generator():
try: try:
async_stream = loop.run_until_complete( async_stream = loop.run_until_complete(self.async_client.request(*args, **kwargs))
self.async_client.request(*args, **kwargs)
)
while True: while True:
chunk = loop.run_until_complete(async_stream.__anext__()) chunk = loop.run_until_complete(async_stream.__anext__())
yield chunk yield chunk
@ -181,9 +179,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
# when using the library client, we should not log to console since many # when using the library client, we should not log to console since many
# of our logs are intended for server-side usage # of our logs are intended for server-side usage
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",") current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
os.environ["TELEMETRY_SINKS"] = ",".join( os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
sink for sink in current_sinks if sink != "console"
)
if config_path_or_template_name.endswith(".yaml"): if config_path_or_template_name.endswith(".yaml"):
config_path = Path(config_path_or_template_name) config_path = Path(config_path_or_template_name)
@ -202,9 +198,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
async def initialize(self): async def initialize(self):
try: try:
self.impls = await construct_stack( self.impls = await construct_stack(self.config, self.custom_provider_registry)
self.config, self.custom_provider_registry
)
except ModuleNotFoundError as _e: except ModuleNotFoundError as _e:
cprint(_e.msg, "red") cprint(_e.msg, "red")
cprint( cprint(
@ -247,9 +241,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
func = getattr(impl, endpoint.name) func = getattr(impl, endpoint.name)
if endpoint.method not in endpoint_impls: if endpoint.method not in endpoint_impls:
endpoint_impls[endpoint.method] = {} endpoint_impls[endpoint.method] = {}
endpoint_impls[endpoint.method][ endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = func
_convert_path_to_regex(endpoint.route)
] = func
self.endpoint_impls = endpoint_impls self.endpoint_impls = endpoint_impls
return True return True
@ -266,9 +258,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
raise ValueError("Client not initialized") raise ValueError("Client not initialized")
if self.provider_data: if self.provider_data:
set_request_provider_data( set_request_provider_data({"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)})
{"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
)
if stream: if stream:
response = await self._call_streaming( response = await self._call_streaming(
@ -408,9 +398,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
) )
return await response.parse() return await response.parse()
def _convert_body( def _convert_body(self, path: str, method: str, body: Optional[dict] = None) -> dict:
self, path: str, method: str, body: Optional[dict] = None
) -> dict:
if not body: if not body:
return {} return {}
@ -425,7 +413,5 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
for param_name, param in sig.parameters.items(): for param_name, param in sig.parameters.items():
if param_name in body: if param_name in body:
value = body.get(param_name) value = body.get(param_name)
converted_body[param_name] = convert_to_pydantic( converted_body[param_name] = convert_to_pydantic(param.annotation, value)
param.annotation, value
)
return converted_body return converted_body


@ -115,9 +115,7 @@ async def resolve_impls(
- flatmaps, sorts and resolves the providers in dependency order - flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation - for each API, produces either a (local, passthrough or router) implementation
""" """
routing_table_apis = set( routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
x.routing_table_api for x in builtin_automatically_routed_apis()
)
router_apis = set(x.router_api for x in builtin_automatically_routed_apis()) router_apis = set(x.router_api for x in builtin_automatically_routed_apis())
providers_with_specs = {} providers_with_specs = {}
@ -125,16 +123,12 @@ async def resolve_impls(
for api_str, providers in run_config.providers.items(): for api_str, providers in run_config.providers.items():
api = Api(api_str) api = Api(api_str)
if api in routing_table_apis: if api in routing_table_apis:
raise ValueError( raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")
f"Provider for `{api_str}` is automatically provided and cannot be overridden"
)
specs = {} specs = {}
for provider in providers: for provider in providers:
if provider.provider_type not in provider_registry[api]: if provider.provider_type not in provider_registry[api]:
raise ValueError( raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
f"Provider `{provider.provider_type}` is not available for API `{api}`"
)
p = provider_registry[api][provider.provider_type] p = provider_registry[api][provider.provider_type]
if p.deprecation_error: if p.deprecation_error:
@ -145,9 +139,7 @@ async def resolve_impls(
log.warning( log.warning(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}", f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
) )
p.deps__ = [a.value for a in p.api_dependencies] + [ p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
a.value for a in p.optional_api_dependencies
]
spec = ProviderWithSpec( spec = ProviderWithSpec(
spec=p, spec=p,
**(provider.model_dump()), **(provider.model_dump()),
@ -158,9 +150,7 @@ async def resolve_impls(
providers_with_specs[key] = specs providers_with_specs[key] = specs
apis_to_serve = run_config.apis or set( apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys()) list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
+ [x.value for x in routing_table_apis]
+ [x.value for x in router_apis]
) )
for info in builtin_automatically_routed_apis(): for info in builtin_automatically_routed_apis():
@ -197,9 +187,7 @@ async def resolve_impls(
) )
} }
sorted_providers = topological_sort( sorted_providers = topological_sort({k: v.values() for k, v in providers_with_specs.items()})
{k: v.values() for k, v in providers_with_specs.items()}
)
apis = [x[1].spec.api for x in sorted_providers] apis = [x[1].spec.api for x in sorted_providers]
sorted_providers.append( sorted_providers.append(
( (
@ -237,9 +225,7 @@ async def resolve_impls(
inner_impls = {} inner_impls = {}
if isinstance(provider.spec, RoutingTableProviderSpec): if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[ inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
f"inner-{provider.spec.router_api.value}"
]
impl = await instantiate_provider( impl = await instantiate_provider(
provider, provider,
@ -336,10 +322,7 @@ async def instantiate_provider(
# TODO: check compliance for special tool groups # TODO: check compliance for special tool groups
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
check_protocol_compliance(impl, protocols[provider_spec.api]) check_protocol_compliance(impl, protocols[provider_spec.api])
if ( if not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols:
not isinstance(provider_spec, AutoRoutedProviderSpec)
and provider_spec.api in additional_protocols
):
additional_api, _, _ = additional_protocols[provider_spec.api] additional_api, _, _ = additional_protocols[provider_spec.api]
check_protocol_compliance(impl, additional_api) check_protocol_compliance(impl, additional_api)
@ -367,19 +350,12 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
obj_params = set(obj_sig.parameters) obj_params = set(obj_sig.parameters)
obj_params.discard("self") obj_params.discard("self")
if not (proto_params <= obj_params): if not (proto_params <= obj_params):
log.error( log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
)
missing_methods.append((name, "signature_mismatch")) missing_methods.append((name, "signature_mismatch"))
else: else:
# Check if the method is actually implemented in the class # Check if the method is actually implemented in the class
method_owner = next( method_owner = next((cls for cls in mro if name in cls.__dict__), None)
(cls for cls in mro if name in cls.__dict__), None if method_owner is None or method_owner.__name__ == protocol.__name__:
)
if (
method_owner is None
or method_owner.__name__ == protocol.__name__
):
missing_methods.append((name, "not_actually_implemented")) missing_methods.append((name, "not_actually_implemented"))
if missing_methods: if missing_methods:


@ -85,9 +85,7 @@ class VectorIORouter(VectorIO):
chunks: List[Chunk], chunks: List[Chunk],
ttl_seconds: Optional[int] = None, ttl_seconds: Optional[int] = None,
) -> None: ) -> None:
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks( return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
vector_db_id, chunks, ttl_seconds
)
async def query_chunks( async def query_chunks(
self, self,
@ -95,9 +93,7 @@ class VectorIORouter(VectorIO):
query: InterleavedContent, query: InterleavedContent,
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks( return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
vector_db_id, query, params
)
class InferenceRouter(Inference): class InferenceRouter(Inference):
@ -123,9 +119,7 @@ class InferenceRouter(Inference):
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None, model_type: Optional[ModelType] = None,
) -> None: ) -> None:
await self.routing_table.register_model( await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
model_id, provider_model_id, provider_id, metadata, model_type
)
async def chat_completion( async def chat_completion(
self, self,
@ -143,9 +137,7 @@ class InferenceRouter(Inference):
if model is None: if model is None:
raise ValueError(f"Model '{model_id}' not found") raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding: if model.model_type == ModelType.embedding:
raise ValueError( raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
params = dict( params = dict(
model_id=model_id, model_id=model_id,
messages=messages, messages=messages,
@ -176,9 +168,7 @@ class InferenceRouter(Inference):
if model is None: if model is None:
raise ValueError(f"Model '{model_id}' not found") raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding: if model.model_type == ModelType.embedding:
raise ValueError( raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
provider = self.routing_table.get_provider_impl(model_id) provider = self.routing_table.get_provider_impl(model_id)
params = dict( params = dict(
model_id=model_id, model_id=model_id,
@ -202,9 +192,7 @@ class InferenceRouter(Inference):
if model is None: if model is None:
raise ValueError(f"Model '{model_id}' not found") raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.llm: if model.model_type == ModelType.llm:
raise ValueError( raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
f"Model '{model_id}' is an LLM model and does not support embeddings"
)
return await self.routing_table.get_provider_impl(model_id).embeddings( return await self.routing_table.get_provider_impl(model_id).embeddings(
model_id=model_id, model_id=model_id,
contents=contents, contents=contents,
@ -231,9 +219,7 @@ class SafetyRouter(Safety):
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
) -> Shield: ) -> Shield:
return await self.routing_table.register_shield( return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
shield_id, provider_shield_id, provider_id, params
)
async def run_shield( async def run_shield(
self, self,
@ -268,9 +254,7 @@ class DatasetIORouter(DatasetIO):
page_token: Optional[str] = None, page_token: Optional[str] = None,
filter_condition: Optional[str] = None, filter_condition: Optional[str] = None,
) -> PaginatedRowsResult: ) -> PaginatedRowsResult:
return await self.routing_table.get_provider_impl( return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
dataset_id
).get_rows_paginated(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=rows_in_page, rows_in_page=rows_in_page,
page_token=page_token, page_token=page_token,
@ -305,9 +289,7 @@ class ScoringRouter(Scoring):
) -> ScoreBatchResponse: ) -> ScoreBatchResponse:
res = {} res = {}
for fn_identifier in scoring_functions.keys(): for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl( score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
fn_identifier
).score_batch(
dataset_id=dataset_id, dataset_id=dataset_id,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
) )
@ -328,9 +310,7 @@ class ScoringRouter(Scoring):
res = {} res = {}
# look up and map each scoring function to its provider impl # look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys(): for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl( score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
fn_identifier
).score(
input_rows=input_rows, input_rows=input_rows,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
) )
@ -381,9 +361,7 @@ class EvalRouter(Eval):
task_id: str, task_id: str,
job_id: str, job_id: str,
) -> Optional[JobStatus]: ) -> Optional[JobStatus]:
return await self.routing_table.get_provider_impl(task_id).job_status( return await self.routing_table.get_provider_impl(task_id).job_status(task_id, job_id)
task_id, job_id
)
async def job_cancel( async def job_cancel(
self, self,
@ -420,9 +398,9 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_ids: List[str], vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None, query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult: ) -> RAGQueryResult:
return await self.routing_table.get_provider_impl( return await self.routing_table.get_provider_impl("query_from_memory").query(
"query_from_memory" content, vector_db_ids, query_config
).query(content, vector_db_ids, query_config) )
async def insert( async def insert(
self, self,
@ -430,9 +408,9 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_id: str, vector_db_id: str,
chunk_size_in_tokens: int = 512, chunk_size_in_tokens: int = 512,
) -> None: ) -> None:
return await self.routing_table.get_provider_impl( return await self.routing_table.get_provider_impl("insert_into_memory").insert(
"insert_into_memory" documents, vector_db_id, chunk_size_in_tokens
).insert(documents, vector_db_id, chunk_size_in_tokens) )
def __init__( def __init__(
self, self,
@ -460,6 +438,4 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ) -> List[ToolDef]:
return await self.routing_table.get_provider_impl(tool_group_id).list_tools( return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
tool_group_id, mcp_endpoint
)


@ -94,9 +94,7 @@ class CommonRoutingTableImpl(RoutingTable):
self.dist_registry = dist_registry self.dist_registry = dist_registry
async def initialize(self) -> None: async def initialize(self) -> None:
async def add_objects( async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
objs: List[RoutableObjectWithProvider], provider_id: str, cls
) -> None:
for obj in objs: for obj in objs:
if cls is None: if cls is None:
obj.provider_id = provider_id obj.provider_id = provider_id
@ -131,9 +129,7 @@ class CommonRoutingTableImpl(RoutingTable):
for p in self.impls_by_provider_id.values(): for p in self.impls_by_provider_id.values():
await p.shutdown() await p.shutdown()
def get_provider_impl( def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
self, routing_key: str, provider_id: Optional[str] = None
) -> Any:
def apiname_object(): def apiname_object():
if isinstance(self, ModelsRoutingTable): if isinstance(self, ModelsRoutingTable):
return ("Inference", "model") return ("Inference", "model")
@ -171,9 +167,7 @@ class CommonRoutingTableImpl(RoutingTable):
raise ValueError(f"Provider not found for `{routing_key}`") raise ValueError(f"Provider not found for `{routing_key}`")
async def get_object_by_identifier( async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
# Get from disk registry # Get from disk registry
obj = await self.dist_registry.get(type, identifier) obj = await self.dist_registry.get(type, identifier)
if not obj: if not obj:
@ -183,13 +177,9 @@ class CommonRoutingTableImpl(RoutingTable):
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
await self.dist_registry.delete(obj.type, obj.identifier) await self.dist_registry.delete(obj.type, obj.identifier)
await unregister_object_from_provider( await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
obj, self.impls_by_provider_id[obj.provider_id]
)
async def register_object( async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
self, obj: RoutableObjectWithProvider
) -> RoutableObjectWithProvider:
# if provider_id is not specified, pick an arbitrary one from existing entries # if provider_id is not specified, pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0: if not obj.provider_id and len(self.impls_by_provider_id) > 0:
obj.provider_id = list(self.impls_by_provider_id.keys())[0] obj.provider_id = list(self.impls_by_provider_id.keys())[0]
@ -244,9 +234,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
if model_type is None: if model_type is None:
model_type = ModelType.llm model_type = ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding: if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError( raise ValueError("Embedding model must have an embedding dimension in its metadata")
"Embedding model must have an embedding dimension in its metadata"
)
model = Model( model = Model(
identifier=model_id, identifier=model_id,
provider_resource_id=provider_model_id, provider_resource_id=provider_model_id,
@ -266,9 +254,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> ListShieldsResponse: async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse( return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
data=await self.get_all_with_type(ResourceType.shield.value)
)
async def get_shield(self, identifier: str) -> Optional[Shield]: async def get_shield(self, identifier: str) -> Optional[Shield]:
return await self.get_object_by_identifier("shield", identifier) return await self.get_object_by_identifier("shield", identifier)
@ -340,9 +326,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
if model.model_type != ModelType.embedding: if model.model_type != ModelType.embedding:
raise ValueError(f"Model {embedding_model} is not an embedding model") raise ValueError(f"Model {embedding_model} is not an embedding model")
if "embedding_dimension" not in model.metadata: if "embedding_dimension" not in model.metadata:
raise ValueError( raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
f"Model {embedding_model} does not have an embedding dimension"
)
vector_db_data = { vector_db_data = {
"identifier": vector_db_id, "identifier": vector_db_id,
"type": ResourceType.vector_db.value, "type": ResourceType.vector_db.value,
@ -364,9 +348,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse: async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse( return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
data=await self.get_all_with_type(ResourceType.dataset.value)
)
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]: async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
return await self.get_object_by_identifier("dataset", dataset_id) return await self.get_object_by_identifier("dataset", dataset_id)
@ -411,9 +393,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse( return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
data=await self.get_all_with_type(ResourceType.scoring_function.value)
)
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
return await self.get_object_by_identifier("scoring_function", scoring_fn_id) return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
@ -510,12 +490,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
args: Optional[Dict[str, Any]] = None, args: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
tools = [] tools = []
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools( tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
toolgroup_id, mcp_endpoint tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
)
tool_host = (
ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
)
for tool_def in tool_defs: for tool_def in tool_defs:
tools.append( tools.append(


@@ -43,9 +43,7 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
        if api == Api.tool_runtime:
            for tool_group in SpecialToolGroup:
                sub_protocol = toolgroup_protocols[tool_group]
-                sub_protocol_methods = inspect.getmembers(
-                    sub_protocol, predicate=inspect.isfunction
-                )
+                sub_protocol_methods = inspect.getmembers(sub_protocol, predicate=inspect.isfunction)
                for name, method in sub_protocol_methods:
                    if not hasattr(method, "__webmethod__"):
                        continue


@ -76,9 +76,7 @@ async def global_exception_handler(request: Request, exc: Exception):
traceback.print_exception(exc) traceback.print_exception(exc)
http_exc = translate_exception(exc) http_exc = translate_exception(exc)
return JSONResponse( return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}})
status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}}
)
def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]: def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]:
@ -178,9 +176,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
is_streaming = is_streaming_request(func.__name__, request, **kwargs) is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try: try:
if is_streaming: if is_streaming:
return StreamingResponse( return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
sse_generator(func(**kwargs)), media_type="text/event-stream"
)
else: else:
value = func(**kwargs) value = func(**kwargs)
return await maybe_await(value) return await maybe_await(value)
@ -190,11 +186,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
sig = inspect.signature(func) sig = inspect.signature(func)
new_params = [ new_params = [inspect.Parameter("request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request)]
inspect.Parameter(
"request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request
)
]
new_params.extend(sig.parameters.values()) new_params.extend(sig.parameters.values())
path_params = extract_path_params(route) path_params = extract_path_params(route)
@ -202,15 +194,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
# Annotate parameters that are in the path with Path(...) and others with Body(...) # Annotate parameters that are in the path with Path(...) and others with Body(...)
new_params = [new_params[0]] + [ new_params = [new_params[0]] + [
( (
param.replace( param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
annotation=Annotated[
param.annotation, FastapiPath(..., title=param.name)
]
)
if param.name in path_params if param.name in path_params
else param.replace( else param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
annotation=Annotated[param.annotation, Body(..., embed=True)]
)
) )
for param in new_params[1:] for param in new_params[1:]
] ]
@ -244,12 +230,8 @@ class ClientVersionMiddleware:
client_version = headers.get(b"x-llamastack-client-version", b"").decode() client_version = headers.get(b"x-llamastack-client-version", b"").decode()
if client_version: if client_version:
try: try:
client_version_parts = tuple( client_version_parts = tuple(map(int, client_version.split(".")[:2]))
map(int, client_version.split(".")[:2]) server_version_parts = tuple(map(int, self.server_version.split(".")[:2]))
)
server_version_parts = tuple(
map(int, self.server_version.split(".")[:2])
)
if client_version_parts != server_version_parts: if client_version_parts != server_version_parts:
async def send_version_error(send): async def send_version_error(send):
@ -267,9 +249,7 @@ class ClientVersionMiddleware:
} }
} }
).encode() ).encode()
await send( await send({"type": "http.response.body", "body": error_msg})
{"type": "http.response.body", "body": error_msg}
)
return await send_version_error(send) return await send_version_error(send)
except (ValueError, IndexError): except (ValueError, IndexError):
@ -296,9 +276,7 @@ def main():
default=int(os.getenv("LLAMA_STACK_PORT", 8321)), default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
help="Port to listen on", help="Port to listen on",
) )
parser.add_argument( parser.add_argument("--disable-ipv6", action="store_true", help="Whether to disable IPv6 support")
"--disable-ipv6", action="store_true", help="Whether to disable IPv6 support"
)
parser.add_argument( parser.add_argument(
"--env", "--env",
action="append", action="append",
@ -323,9 +301,7 @@ def main():
raise ValueError(f"Config file {config_file} does not exist") raise ValueError(f"Config file {config_file} does not exist")
print(f"Using config file: {config_file}") print(f"Using config file: {config_file}")
elif args.template: elif args.template:
config_file = ( config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
)
if not config_file.exists(): if not config_file.exists():
raise ValueError(f"Template {args.template} does not exist") raise ValueError(f"Template {args.template} does not exist")
print(f"Using template {args.template} config file: {config_file}") print(f"Using template {args.template} config file: {config_file}")
@ -383,9 +359,7 @@ def main():
impl_method = getattr(impl, endpoint.name) impl_method = getattr(impl, endpoint.name)
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings( warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
"ignore", category=UserWarning, module="pydantic._internal._fields"
)
getattr(app, endpoint.method)(endpoint.route, response_model=None)( getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route( create_dynamic_typed_route(
impl_method, impl_method,
@ -416,9 +390,7 @@ def main():
def extract_path_params(route: str) -> List[str]: def extract_path_params(route: str) -> List[str]:
segments = route.split("/") segments = route.split("/")
params = [ params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]
seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")
]
return params return params


@ -110,9 +110,7 @@ class EnvVarError(Exception):
def __init__(self, var_name: str, path: str = ""): def __init__(self, var_name: str, path: str = ""):
self.var_name = var_name self.var_name = var_name
self.path = path self.path = path
super().__init__( super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}")
f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}"
)
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]: def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
@ -187,9 +185,7 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
if not key: if not key:
raise ValueError(f"Empty key in environment variable pair: {env_pair}") raise ValueError(f"Empty key in environment variable pair: {env_pair}")
if not all(c.isalnum() or c == "_" for c in key): if not all(c.isalnum() or c == "_" for c in key):
raise ValueError( raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}")
f"Key must contain only alphanumeric characters and underscores: {key}"
)
return key, value return key, value
except ValueError as e: except ValueError as e:
raise ValueError( raise ValueError(
@ -202,20 +198,14 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
async def construct_stack( async def construct_stack(
run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
) -> Dict[Api, Any]: ) -> Dict[Api, Any]:
dist_registry, _ = await create_dist_registry( dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
run_config.metadata_store, run_config.image_name impls = await resolve_impls(run_config, provider_registry or get_provider_registry(), dist_registry)
)
impls = await resolve_impls(
run_config, provider_registry or get_provider_registry(), dist_registry
)
await register_resources(run_config, impls) await register_resources(run_config, impls)
return impls return impls
def get_stack_run_config_from_template(template: str) -> StackRunConfig: def get_stack_run_config_from_template(template: str) -> StackRunConfig:
template_path = ( template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
)
with importlib.resources.as_file(template_path) as path: with importlib.resources.as_file(template_path) as path:
if not path.exists(): if not path.exists():


@ -25,9 +25,7 @@ class DistributionRegistry(Protocol):
def get_cached(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ... def get_cached(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ...
async def update( async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: ...
self, obj: RoutableObjectWithProvider
) -> RoutableObjectWithProvider: ...
async def register(self, obj: RoutableObjectWithProvider) -> bool: ... async def register(self, obj: RoutableObjectWithProvider) -> bool: ...
@ -61,9 +59,7 @@ class DiskDistributionRegistry(DistributionRegistry):
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass
def get_cached( def get_cached(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
# Disk registry does not have a cache # Disk registry does not have a cache
raise NotImplementedError("Disk registry does not have a cache") raise NotImplementedError("Disk registry does not have a cache")
@ -72,12 +68,8 @@ class DiskDistributionRegistry(DistributionRegistry):
values = await self.kvstore.range(start_key, end_key) values = await self.kvstore.range(start_key, end_key)
return _parse_registry_values(values) return _parse_registry_values(values)
async def get( async def get(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
self, type: str, identifier: str json_str = await self.kvstore.get(KEY_FORMAT.format(type=type, identifier=identifier))
) -> Optional[RoutableObjectWithProvider]:
json_str = await self.kvstore.get(
KEY_FORMAT.format(type=type, identifier=identifier)
)
if not json_str: if not json_str:
return None return None
@ -143,9 +135,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
async def initialize(self) -> None: async def initialize(self) -> None:
await self._ensure_initialized() await self._ensure_initialized()
def get_cached( def get_cached(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
return self.cache.get((type, identifier), None) return self.cache.get((type, identifier), None)
async def get_all(self) -> List[RoutableObjectWithProvider]: async def get_all(self) -> List[RoutableObjectWithProvider]:
@ -153,9 +143,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
async with self._locked_cache() as cache: async with self._locked_cache() as cache:
return list(cache.values()) return list(cache.values())
async def get( async def get(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
await self._ensure_initialized() await self._ensure_initialized()
cache_key = (type, identifier) cache_key = (type, identifier)
@ -197,9 +185,7 @@ async def create_dist_registry(
dist_kvstore = await kvstore_impl(metadata_store) dist_kvstore = await kvstore_impl(metadata_store)
else: else:
dist_kvstore = await kvstore_impl( dist_kvstore = await kvstore_impl(
SqliteKVStoreConfig( SqliteKVStoreConfig(db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix())
db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix()
)
) )
dist_registry = CachedDiskDistributionRegistry(dist_kvstore) dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
await dist_registry.initialize() await dist_registry.initialize()


@@ -161,9 +161,7 @@ async def test_duplicate_provider_registration(config):
    result = await cached_registry.get("vector_db", "test_vector_db_2")
    assert result is not None
-    assert (
-        result.embedding_model == original_vector_db.embedding_model
-    )  # Original values preserved
+    assert result.embedding_model == original_vector_db.embedding_model  # Original values preserved


@pytest.mark.asyncio
@@ -193,14 +191,9 @@ async def test_get_all_objects(config):
    # Verify each vector_db was stored correctly
    for original_vector_db in test_vector_dbs:
-        matching_vector_dbs = [
-            v for v in all_results if v.identifier == original_vector_db.identifier
-        ]
+        matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier]
        assert len(matching_vector_dbs) == 1
        stored_vector_db = matching_vector_dbs[0]

        assert stored_vector_db.embedding_model == original_vector_db.embedding_model
        assert stored_vector_db.provider_id == original_vector_db.provider_id
-        assert (
-            stored_vector_db.embedding_dimension
-            == original_vector_db.embedding_dimension
-        )
+        assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension


@@ -22,15 +22,11 @@ def main():
    )

    # Playground pages
-    chat_page = st.Page(
-        "page/playground/chat.py", title="Chat", icon="💬", default=True
-    )
+    chat_page = st.Page("page/playground/chat.py", title="Chat", icon="💬", default=True)
    rag_page = st.Page("page/playground/rag.py", title="RAG", icon="💬", default=False)

    # Distribution pages
-    resources_page = st.Page(
-        "page/distribution/resources.py", title="Resources", icon="🔍", default=False
-    )
+    resources_page = st.Page("page/distribution/resources.py", title="Resources", icon="🔍", default=False)
    provider_page = st.Page(
        "page/distribution/providers.py",
        title="API Providers",


@@ -23,15 +23,11 @@ class LlamaStackApi:
            },
        )

-    def run_scoring(
-        self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]
-    ):
+    def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]):
        """Run scoring on a single row"""
        if not scoring_params:
            scoring_params = {fn_id: None for fn_id in scoring_function_ids}
-        return self.client.scoring.score(
-            input_rows=[row], scoring_functions=scoring_params
-        )
+        return self.client.scoring.score(input_rows=[row], scoring_functions=scoring_params)


llama_stack_api = LlamaStackApi()


@@ -11,9 +11,7 @@ from modules.api import llama_stack_api
def datasets():
    st.header("Datasets")
-    datasets_info = {
-        d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list()
-    }
+    datasets_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list()}

    if len(datasets_info) > 0:
        selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys()))
        st.json(datasets_info[selected_dataset], expanded=True)


@@ -12,12 +12,8 @@ def eval_tasks():
    # Eval Tasks Section
    st.header("Eval Tasks")

-    eval_tasks_info = {
-        d.identifier: d.to_dict() for d in llama_stack_api.client.eval_tasks.list()
-    }
+    eval_tasks_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.eval_tasks.list()}

    if len(eval_tasks_info) > 0:
-        selected_eval_task = st.selectbox(
-            "Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect"
-        )
+        selected_eval_task = st.selectbox("Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect")
        st.json(eval_tasks_info[selected_eval_task], expanded=True)


@@ -11,9 +11,7 @@ from modules.api import llama_stack_api
def models():
    # Models Section
    st.header("Models")

-    models_info = {
-        m.identifier: m.to_dict() for m in llama_stack_api.client.models.list()
-    }
+    models_info = {m.identifier: m.to_dict() for m in llama_stack_api.client.models.list()}
    selected_model = st.selectbox("Select a model", list(models_info.keys()))
    st.json(models_info[selected_model])


@@ -11,12 +11,7 @@ from modules.api import llama_stack_api
def scoring_functions():
    st.header("Scoring Functions")

-    scoring_functions_info = {
-        s.identifier: s.to_dict()
-        for s in llama_stack_api.client.scoring_functions.list()
-    }
+    scoring_functions_info = {s.identifier: s.to_dict() for s in llama_stack_api.client.scoring_functions.list()}

-    selected_scoring_function = st.selectbox(
-        "Select a scoring function", list(scoring_functions_info.keys())
-    )
+    selected_scoring_function = st.selectbox("Select a scoring function", list(scoring_functions_info.keys()))
    st.json(scoring_functions_info[selected_scoring_function], expanded=True)


@@ -12,9 +12,7 @@ def shields():
    # Shields Section
    st.header("Shields")

    shields_info = {s.identifier: s.to_dict() for s in llama_stack_api.client.shields.list()}

    selected_shield = st.selectbox("Select a shield", list(shields_info.keys()))
    st.json(shields_info[selected_shield])


@@ -10,14 +10,10 @@ from modules.api import llama_stack_api
def vector_dbs():
    st.header("Vector Databases")
    vector_dbs_info = {v.identifier: v.to_dict() for v in llama_stack_api.client.vector_dbs.list()}

    if len(vector_dbs_info) > 0:
        selected_vector_db = st.selectbox("Select a vector database", list(vector_dbs_info.keys()))
        st.json(vector_dbs_info[selected_vector_db])
    else:
        st.info("No vector databases found")


@@ -14,7 +14,6 @@ from modules.utils import process_dataset
def application_evaluation_page():
    st.set_page_config(page_title="Evaluations (Scoring)", page_icon="🦙")
    st.title("📊 Evaluations (Scoring)")

@@ -83,9 +82,7 @@ def application_evaluation_page():
                    try:
                        new_params[param_name] = json.loads(value)
                    except json.JSONDecodeError:
                        st.error(f"Invalid JSON for **{param_name}** in {scoring_fn_id}")

            st.json(new_params)
            scoring_params[scoring_fn_id] = new_params

@@ -128,9 +125,7 @@ def application_evaluation_page():
                output_res[fn_id].append(score_res.results[fn_id].score_rows[0])

            # Display current row results using separate containers
            progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
            results_container.json(
                score_res.to_json(),
                expanded=2,


@@ -195,7 +195,6 @@ def run_evaluation_3():
    # Add run button and handle evaluation
    if st.button("Run Evaluation"):
        progress_text = "Running evaluation..."
        progress_bar = st.progress(0, text=progress_text)

        rows = rows.rows

@@ -233,9 +232,7 @@ def run_evaluation_3():
                    output_res[scoring_fn] = []
                output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])

            progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
            results_container.json(eval_res, expanded=2)

        progress_bar.progress(1.0, text="Evaluation complete!")

@@ -247,7 +244,6 @@ def run_evaluation_3():
def native_evaluation_page():
    st.set_page_config(page_title="Evaluations (Generation + Scoring)", page_icon="🦙")
    st.title("📊 Evaluations (Generation + Scoring)")


@@ -11,9 +11,7 @@ from modules.api import llama_stack_api
with st.sidebar:
    st.header("Configuration")
    available_models = llama_stack_api.client.models.list()
    available_models = [model.identifier for model in available_models if model.model_type == "llm"]
    selected_model = st.selectbox(
        "Choose a model",
        available_models,

@@ -128,6 +126,4 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
            full_response = response
        message_placeholder.markdown(full_response.completion_message.content)

        st.session_state.messages.append({"role": "assistant", "content": full_response})


@@ -74,9 +74,7 @@ def rag_chat_page():
    )

    available_models = llama_stack_api.client.models.list()
    available_models = [model.identifier for model in available_models if model.model_type == "llm"]
    selected_model = st.selectbox(
        "Choose a model",
        available_models,

@@ -137,9 +135,7 @@ def rag_chat_page():
            dict(
                name="builtin::rag",
                args={
                    "vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
                },
            )
        ],

@@ -186,9 +182,7 @@ def rag_chat_page():
                message_placeholder.markdown(full_response + "▌")
            message_placeholder.markdown(full_response)

        st.session_state.messages.append({"role": "assistant", "content": full_response})


rag_chat_page()


@@ -8,9 +8,7 @@ import os
from pathlib import Path

LLAMA_STACK_CONFIG_DIR = Path(os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/")))

DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
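
(Illustrative note, not part of the diff: the constant above resolves from the LLAMA_STACK_CONFIG_DIR environment variable and falls back to ~/.llama/. A small sketch of that resolution order, using a hypothetical override path:)

import os
from pathlib import Path

# Sketch only: the override path is a hypothetical example.
os.environ["LLAMA_STACK_CONFIG_DIR"] = "/tmp/llama-stack-demo"

config_dir = Path(os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/")))
print(config_dir)                    # /tmp/llama-stack-demo
print(config_dir / "distributions")  # /tmp/llama-stack-demo/distributions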


@@ -31,15 +31,11 @@ def is_list_of_primitives(field_type):
def is_basemodel_without_fields(typ):
    return inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0


def can_recurse(typ):
    return inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0


def get_literal_values(field):

@@ -72,7 +68,7 @@ def is_discriminated_union(typ) -> bool:
    if isinstance(typ, FieldInfo):
        return typ.discriminator
    else:
        if get_origin(typ) is not Annotated:
            return False
        args = get_args(typ)
        return len(args) >= 2 and args[1].discriminator

@@ -116,9 +112,7 @@ def prompt_for_discriminated_union(
        chosen_type = type_map[discriminator_value]
        log.info(f"\nConfiguring {chosen_type.__name__}:")

        if existing_value and (getattr(existing_value, discriminator) != discriminator_value):
            existing_value = None

        sub_config = prompt_for_config(chosen_type, existing_value)

@@ -134,9 +128,7 @@ def prompt_for_discriminated_union(
#
# doesn't support List[nested_class] yet or Dicts of any kind. needs a bunch of
# unit tests for coverage.
def prompt_for_config(config_type: type[BaseModel], existing_config: Optional[BaseModel] = None) -> BaseModel:
    """
    Recursively prompt the user for configuration values based on a Pydantic BaseModel.

@@ -150,17 +142,11 @@ def prompt_for_config(
    for field_name, field in config_type.__fields__.items():
        field_type = field.annotation
        existing_value = getattr(existing_config, field_name) if existing_config else None
        if existing_value:
            default_value = existing_value
        else:
            default_value = field.default if not isinstance(field.default, PydanticUndefinedType) else None
        is_required = field.is_required

        # Skip fields with Literal type

@@ -183,15 +169,11 @@ def prompt_for_config(
                    config_data[field_name] = validated_value
                    break
                except KeyError:
                    log.error(f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}")
            continue

        if is_discriminated_union(field):
            config_data[field_name] = prompt_for_discriminated_union(field_name, field, existing_value)
            continue

        if is_optional(field_type) and can_recurse(get_non_none_type(field_type)):

@@ -202,9 +184,7 @@ def prompt_for_config(
                nested_type = get_non_none_type(field_type)
                log.info(f"Entering sub-configuration for {field_name}:")
                config_data[field_name] = prompt_for_config(nested_type, existing_value)
        elif is_optional(field_type) and is_discriminated_union(get_non_none_type(field_type)):
            prompt = f"Do you want to configure {field_name}? (y/n): "
            if input(prompt).lower() == "n":
                config_data[field_name] = None

@@ -260,16 +240,12 @@ def prompt_for_config(
                    try:
                        value = json.loads(user_input)
                        if not isinstance(value, list):
                            raise ValueError("Input must be a JSON-encoded list")
                        element_type = get_args(field_type)[0]
                        value = [element_type(item) for item in value]
                    except json.JSONDecodeError:
                        log.error('Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]')
                        continue
                    except ValueError as e:
                        log.error(f"{str(e)}")

@@ -279,20 +255,14 @@ def prompt_for_config(
                    try:
                        value = json.loads(user_input)
                        if not isinstance(value, dict):
                            raise ValueError("Input must be a JSON-encoded dictionary")
                    except json.JSONDecodeError:
                        log.error("Invalid JSON. Please enter a valid JSON-encoded dict.")
                        continue

                # Convert the input to the correct type
                elif inspect.isclass(field_type) and issubclass(field_type, BaseModel):
                    # For nested BaseModels, we assume a dictionary-like string input
                    import ast

@@ -301,16 +271,12 @@ def prompt_for_config(
                        value = field_type(user_input)
                    except ValueError:
                        log.error(f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}")
                        continue

                try:
                    # Validate the field using our manual validation function
                    validated_value = manually_validate_field(config_type, field_name, value)
                    config_data[field_name] = validated_value
                    break
                except ValueError as e:


@@ -11,9 +11,7 @@ from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import MetaReferenceAgentsImplConfig


async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]):
    from .agents import MetaReferenceAgentsImpl

    impl = MetaReferenceAgentsImpl(


@@ -74,9 +74,7 @@ log = logging.getLogger(__name__)
def make_random_string(length: int = 8):
    return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))


TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")

@@ -153,9 +151,7 @@ class ChatAgent(ShieldRunnerMixin):
    async def create_session(self, name: str) -> str:
        return await self.storage.create_session(name)

    async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
        with tracing.span("create_and_execute_turn") as span:
            span.set_attribute("session_id", request.session_id)
            span.set_attribute("agent_id", self.agent_id)

@@ -206,14 +202,9 @@ class ChatAgent(ShieldRunnerMixin):
                    output_message = chunk
                    continue

                assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
                event = chunk.event
                if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
                    steps.append(event.payload.step_details)

                yield chunk

@@ -388,9 +379,7 @@ class ChatAgent(ShieldRunnerMixin):
        tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
        if documents:
            await self.handle_documents(session_id, documents, input_messages, tool_defs)

        if RAG_TOOL_GROUP in toolgroups and len(input_messages) > 0:
            with tracing.span(MEMORY_QUERY_TOOL) as span:

@@ -408,9 +397,7 @@ class ChatAgent(ShieldRunnerMixin):
                vector_db_ids = args.get("vector_db_ids", [])
                query_config = args.get("query_config")
                if query_config:
                    query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
                else:
                    # handle someone passing an empty dict
                    query_config = RAGQueryConfig()

@@ -438,9 +425,7 @@ class ChatAgent(ShieldRunnerMixin):
                    )
                )
                result = await self.tool_runtime_api.rag_tool.query(
                    content=concat_interleaved_content([msg.content for msg in input_messages]),
                    vector_db_ids=vector_db_ids,
                    query_config=query_config,
                )

@@ -472,9 +457,7 @@ class ChatAgent(ShieldRunnerMixin):
                        )
                    )
                )
                span.set_attribute("input", [m.model_dump_json() for m in input_messages])
                span.set_attribute("output", retrieved_context)
                span.set_attribute("tool_name", MEMORY_QUERY_TOOL)

@@ -511,9 +494,7 @@ class ChatAgent(ShieldRunnerMixin):
            self.agent_config.model,
            input_messages,
            tools=[
                tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
            ],
            tool_prompt_format=self.agent_config.tool_prompt_format,
            response_format=self.agent_config.response_format,

@@ -560,12 +541,8 @@ class ChatAgent(ShieldRunnerMixin):
                    if event.stop_reason is not None:
                        stop_reason = event.stop_reason
                    span.set_attribute("stop_reason", stop_reason)
                    span.set_attribute("input", [m.model_dump_json() for m in input_messages])
                    span.set_attribute("output", f"content: {content} tool_calls: {tool_calls}")

            stop_reason = stop_reason or StopReason.out_of_tokens

@@ -667,9 +644,7 @@ class ChatAgent(ShieldRunnerMixin):
                    toolgroup_args,
                    tool_to_group,
                )
                assert len(result_messages) == 1, "Currently not supporting multiple messages"
                result_message = result_messages[0]
                span.set_attribute("output", result_message.model_dump_json())

@@ -697,9 +672,7 @@ class ChatAgent(ShieldRunnerMixin):
            # TODO: add tool-input touchpoint and a "start" event for this step also
            # but that needs a lot more refactoring of Tool code potentially

            if out_attachment := _interpret_content_as_attachment(result_message.content):
                # NOTE: when we push this message back to the model, the model may ignore the
                # attached file path etc. since the model is trained to only provide a user message
                # with the summary. We keep all generated attachments and then attach them to final message

@@ -714,22 +687,14 @@ class ChatAgent(ShieldRunnerMixin):
    ) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]:
        # Determine which tools to include
        agent_config_toolgroups = set(
            (toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
            for toolgroup in self.agent_config.toolgroups
        )
        toolgroups_for_turn_set = (
            agent_config_toolgroups
            if toolgroups_for_turn is None
            else {
                (toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
                for toolgroup in toolgroups_for_turn
            }
        )

@@ -759,10 +724,7 @@ class ChatAgent(ShieldRunnerMixin):
                continue

            tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
            for tool_def in tools.data:
                if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
                    tool_name = tool_def.identifier
                    built_in_type = BuiltinTool.brave_search
                    if tool_name == "web_search":

@@ -773,9 +735,7 @@ class ChatAgent(ShieldRunnerMixin):
                    if tool_def_map.get(built_in_type, None):
                        raise ValueError(f"Tool {built_in_type} already exists")

                    tool_def_map[built_in_type] = ToolDefinition(tool_name=built_in_type)
                    tool_to_group[built_in_type] = tool_def.toolgroup_id
                    continue

@@ -821,9 +781,7 @@ class ChatAgent(ShieldRunnerMixin):
        # Save the contents to a tempdir and use its path as a URL if code interpreter is present
        if code_interpreter_tool:
            for c in content_items:
                temp_file_path = os.path.join(self.tempdir, f"{make_random_string()}.txt")
                with open(temp_file_path, "w") as temp_file:
                    temp_file.write(c.content)
                url_items.append(URL(uri=f"file://{temp_file_path}"))

@@ -849,8 +807,7 @@ class ChatAgent(ShieldRunnerMixin):
            # we try to load the data from the URLs and content items as a message to inference
            # and add it to the last message's context
            input_messages[-1].context = "\n".join(
                [doc.content for doc in content_items] + await load_data_from_urls(url_items)
            )

    async def _ensure_vector_db(self, session_id: str) -> str:

@@ -874,9 +831,7 @@ class ChatAgent(ShieldRunnerMixin):
        return vector_db_id

    async def add_to_session_vector_db(self, session_id: str, data: List[Document]) -> None:
        vector_db_id = await self._ensure_vector_db(session_id)
        documents = [
            RAGDocument(

@@ -931,11 +886,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
        else:
            raise ValueError(f"Unsupported URL {url}")

        content.append(TextContentItem(text=f'# There is a file accessible to you at "{filepath}"\n'))

    return ToolResponseMessage(
        call_id="",


@@ -94,16 +94,12 @@ class MetaReferenceAgentsImpl(Agents):
        try:
            agent_config = json.loads(agent_config)
        except json.JSONDecodeError as e:
            raise ValueError(f"Could not JSON decode agent config for {agent_id}") from e

        try:
            agent_config = AgentConfig(**agent_config)
        except Exception as e:
            raise ValueError(f"Could not validate(?) agent config for {agent_id}") from e

        return ChatAgent(
            agent_id=agent_id,

@@ -115,9 +111,7 @@ class MetaReferenceAgentsImpl(Agents):
            tool_runtime_api=self.tool_runtime_api,
            tool_groups_api=self.tool_groups_api,
            persistence_store=(
                self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store
            ),
        )

@@ -168,22 +162,14 @@ class MetaReferenceAgentsImpl(Agents):
        async for event in agent.create_and_execute_turn(request):
            yield event

    async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
        turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
        turn = json.loads(turn)
        turn = Turn(**turn)
        return turn

    async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
        turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
        turn = json.loads(turn)
        turn = Turn(**turn)
        steps = turn.steps

@@ -203,9 +189,7 @@ class MetaReferenceAgentsImpl(Agents):
        turns = []
        if turn_ids:
            for turn_id in turn_ids:
                turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
                turn = json.loads(turn)
                turn = Turn(**turn)
                turns.append(turn)


@@ -33,9 +33,7 @@ class ShieldRunnerMixin:
        self.input_shields = input_shields
        self.output_shields = output_shields

    async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
        responses = await asyncio.gather(
            *[
                self.safety_api.run_shield(


@@ -64,9 +64,7 @@ class MockInferenceAPI:
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
        async def stream_response():
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(

@@ -104,9 +102,7 @@ class MockInferenceAPI:
class MockSafetyAPI:
    async def run_shield(self, shield_id: str, messages: List[Message]) -> RunShieldResponse:
        return RunShieldResponse(violation=None)

@@ -129,9 +125,7 @@ class MockVectorIOAPI:
class MockToolGroupsAPI:
    async def register_tool_group(self, toolgroup_id: str, provider_id: str, mcp_endpoint=None, args=None) -> None:
        pass

    async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:

@@ -341,26 +335,21 @@ async def test_chat_agent_complex_turn(get_chat_agent):
    assert len(responses) > 0

    step_types = [
        response.event.payload.step_type for response in responses if hasattr(response.event.payload, "step_type")
    ]
    assert StepType.shield_call in step_types, "Shield call step is missing"
    assert StepType.inference in step_types, "Inference step is missing"

    event_types = [
        response.event.payload.event_type for response in responses if hasattr(response.event.payload, "event_type")
    ]
    assert "turn_start" in event_types, "Start event is missing"
    assert "turn_complete" in event_types, "Complete event is missing"

    assert any(isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload) for response in responses), (
        "Turn complete event is missing"
    )
    turn_complete_payload = next(
        response.event.payload
        for response in responses

@@ -380,9 +369,7 @@ async def test_chat_agent_complex_turn(get_chat_agent):
        ([MEMORY_TOOLGROUP, CODE_INTERPRETER_TOOLGROUP], True, True),  # all tools
    ],
)
async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, expected_code_interpreter):
    impl = await get_agents_impl
    agent_config = AgentConfig(
        model="test_model",


@@ -172,9 +172,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
        new_rows_df = pandas.DataFrame(rows)
        new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
        dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)

        url = str(dataset_info.dataset_def.url)
        parsed_url = urlparse(url)

@@ -189,12 +187,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
                raise ValueError("Data URL must be a base64-encoded CSV")

            csv_buffer = dataset_impl.df.to_csv(index=False)
            base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode("utf-8")
            dataset_info.dataset_def.url = URL(uri=f"data:text/csv;base64,{base64_content}")
        else:
            raise ValueError(
                f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."


@@ -91,14 +91,10 @@ class MetaReferenceEvalImpl(
        candidate = task_config.eval_candidate
        scoring_functions = task_def.scoring_functions
        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
        all_rows = await self.datasetio_api.get_rows_paginated(
            dataset_id=dataset_id,
            rows_in_page=(-1 if task_config.num_examples is None else task_config.num_examples),
        )
        res = await self.evaluate_rows(
            task_id=task_id,

@@ -127,9 +123,7 @@ class MetaReferenceEvalImpl(
            input_messages = [UserMessage(**x) for x in input_messages]

            # NOTE: only single-turn agent generation is supported. Create a new session for each input row
            session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
            session_id = session_create_response.session_id

            turn_request = dict(

@@ -138,12 +132,7 @@ class MetaReferenceEvalImpl(
                messages=input_messages,
                stream=True,
            )
            turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)]
            final_event = turn_response[-1].event.payload

            # check if there's a memory retrieval step and extract the context

@@ -152,14 +141,10 @@ class MetaReferenceEvalImpl(
                if step.step_type == StepType.tool_execution.value:
                    for tool_response in step.tool_responses:
                        if tool_response.tool_name == MEMORY_QUERY_TOOL:
                            memory_rag_context = " ".join(x.text for x in tool_response.content)

            agent_generation = {}
            agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content
            if memory_rag_context:
                agent_generation[ColumnName.context.value] = memory_rag_context

@@ -171,9 +156,7 @@ class MetaReferenceEvalImpl(
        self, input_rows: List[Dict[str, Any]], task_config: EvalTaskConfig
    ) -> List[Dict[str, Any]]:
        candidate = task_config.eval_candidate
        assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"

        generations = []
        for x in tqdm(input_rows):

@@ -184,15 +167,9 @@ class MetaReferenceEvalImpl(
                    content=input_content,
                    sampling_params=candidate.sampling_params,
                )
                generations.append({ColumnName.generated_answer.value: response.completion_message.content})
            elif ColumnName.chat_completion_input.value in x:
                chat_completion_input_str = str(x[ColumnName.chat_completion_input.value])
                input_messages = eval(chat_completion_input_str)
                input_messages = [UserMessage(**x) for x in input_messages]
                messages = []

@@ -204,11 +181,7 @@ class MetaReferenceEvalImpl(
                    messages=messages,
                    sampling_params=candidate.sampling_params,
                )
                generations.append({ColumnName.generated_answer.value: response.completion_message.content})
            else:
                raise ValueError("Invalid input row")

@@ -230,10 +203,7 @@ class MetaReferenceEvalImpl(
            raise ValueError(f"Invalid candidate type: {candidate.type}")

        # scoring with generated_answer
        score_input_rows = [input_r | generated_r for input_r, generated_r in zip(input_rows, generations)]

        if task_config.type == "app" and task_config.scoring_params is not None:
            scoring_functions_dict = {

@@ -241,9 +211,7 @@ class MetaReferenceEvalImpl(
                for scoring_fn_id in scoring_functions
            }
        else:
            scoring_functions_dict = {scoring_fn_id: None for scoring_fn_id in scoring_functions}

        score_response = await self.scoring_api.score(
            input_rows=score_input_rows, scoring_functions=scoring_functions_dict


@@ -40,9 +40,7 @@ class MetaReferenceInferenceConfig(BaseModel):
        repos = [m.huggingface_repo for m in permitted_models]
        if model not in (descriptors + repos):
            model_list = "\n\t".join(repos)
            raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
        return model

    @classmethod


@@ -83,9 +83,7 @@ class TokenResult(BaseModel):
class Llama:
    @staticmethod
    def build(
        config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
        model_id: str,
        llama_model: Model,
    ):

@@ -150,9 +148,9 @@ class Llama:
        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
        assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
        assert model_parallel_size == len(checkpoints), (
            f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
        )
        ckpt_path = checkpoints[get_model_parallel_rank()]
        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        with open(Path(ckpt_dir) / "params.json", "r") as f:

@@ -168,9 +166,9 @@ class Llama:
        )
        tokenizer = Tokenizer.get_instance()

        assert model_args.vocab_size == tokenizer.n_words, (
            f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
        )

        if isinstance(config, MetaReferenceQuantizedInferenceConfig):
            if isinstance(config.quantization, Fp8QuantizationConfig):

@@ -193,10 +191,7 @@ class Llama:
                model = convert_to_int4_quantized_model(model, model_args, config)
                model.load_state_dict(state_dict, strict=True)

                if model_args.quantization_args is not None and model_args.quantization_args.spinquant:
                    # Add a wrapper for adding hadamard transform for spinquant.
                    # This needs to be done after loading the state dict otherwise an error will be raised while
                    # loading the state dict.

@@ -206,9 +201,7 @@ class Llama:
                    add_hadamard_transform_for_spinquant(model)
            else:
                raise NotImplementedError("Currently int4 and fp8 are the only supported quantization methods.")
        else:
            if device == "cuda":
                if torch.cuda.is_bf16_supported():

@@ -262,10 +255,7 @@ class Llama:
        params = self.model.params

        if print_input_tokens:
            input_tokens = [self.formatter.vision_token if t == 128256 else t for t in model_input.tokens]
            log.info("Input to model -> " + self.tokenizer.decode(input_tokens))

        prompt_tokens = [model_input.tokens]

@@ -287,13 +277,11 @@ class Llama:
            mask = model_input.vision.mask if model_input.vision is not None else []

            # the method works for bsz > 1 so add a batch dimension
            xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
                batch_images=[images],
                batch_masks=[mask],
                total_len=total_len,
            )

        pad_id = self.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)

@@ -340,9 +328,7 @@ class Llama:
            next_token = next_token.reshape(-1)

            # only replace token if prompt has already been generated
            next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
            tokens[:, cur_pos] = next_token

            target = tokens[:, prev_pos + 1 : cur_pos + 1]

@@ -365,17 +351,11 @@ class Llama:
                    reduction="none",
                    ignore_index=pad_id,
                )
            eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
            yield TokenResult(
                token=next_token[0].item(),
                text=self.tokenizer.decode(next_token.tolist()),
                logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
            )

            prev_pos = cur_pos

@@ -388,11 +368,7 @@ class Llama:
    ) -> Generator:
        sampling_params = request.sampling_params
        max_gen_len = sampling_params.max_tokens
        if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
            max_gen_len = self.model.params.max_seq_len - 1

        model_input = self.formatter.encode_content(request.content)

@@ -417,11 +393,7 @@ class Llama:
    ) -> Generator:
        sampling_params = request.sampling_params
        max_gen_len = sampling_params.max_tokens
        if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
            max_gen_len = self.model.params.max_seq_len - 1

        temperature, top_p = _infer_sampling_params(sampling_params)

@@ -473,9 +445,7 @@ class LogitsProcessor:
        self.token_enforcer = token_enforcer
        self.mask: Optional[torch.Tensor] = None

    def process_logits(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        token_sequence = tokens[0, :].tolist()
        allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)

@@ -510,9 +480,7 @@ def get_logits_processor(
    return LogitsProcessor(token_enforcer)


def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> List[Tuple[int, str, bool]]:
    token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
    regular_tokens = []


@@ -80,9 +80,7 @@ class MetaReferenceInferenceImpl(
    async def load_model(self, model_id, llama_model) -> None:
        log.info(f"Loading model `{model_id}`")
        if self.config.create_distributed_process_group:
            self.generator = LlamaModelParallelGenerator(self.config, model_id, llama_model)
            self.generator.start()
        else:
            self.generator = Llama.build(self.config, model_id, llama_model)

@@ -100,9 +98,7 @@ class MetaReferenceInferenceImpl(
                "No avaible model yet, please register your requested model or add your model in the resouces first"
            )
        elif request.model != self.model_id:
            raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")

    async def unregister_model(self, model_id: str) -> None:
        pass

@@ -184,13 +180,7 @@ class MetaReferenceInferenceImpl(
                if request.logprobs:
                    assert len(token_result.logprobs) == 1

                    logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]

                yield CompletionResponseStreamChunk(
                    delta=text,

@@ -212,9 +202,7 @@ class MetaReferenceInferenceImpl(
            for x in impl():
                yield x

    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
        def impl():
            tokens = []
            logprobs = []

@@ -231,13 +219,7 @@ class MetaReferenceInferenceImpl(
                if request.logprobs:
                    assert len(token_result.logprobs) == 1

                    logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))

            if stop_reason is None:
                stop_reason = StopReason.out_of_tokens

@@ -289,9 +271,7 @@ class MetaReferenceInferenceImpl(
        self.check_model(request)

        # augment and rewrite messages depending on the model
        request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)

        # download media and convert to raw content so we can send it to the model
        request = await convert_request_to_raw(request)

@@ -304,9 +284,7 @@ class MetaReferenceInferenceImpl(
        else:
            return await self._nonstream_chat_completion(request)

    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
        def impl():
            tokens = []
            logprobs = []

@@ -323,20 +301,12 @@ class MetaReferenceInferenceImpl(
                if request.logprobs:
                    assert len(token_result.logprobs) == 1

                    logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))

            if stop_reason is None:
                stop_reason = StopReason.out_of_tokens

            raw_message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
            return ChatCompletionResponse(
                completion_message=CompletionMessage(
                    content=raw_message.content,

@@ -352,9 +322,7 @@ class MetaReferenceInferenceImpl(
        else:
            return impl()

    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        def impl():
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(

@@ -405,13 +373,7 @@ class MetaReferenceInferenceImpl(
                if request.logprobs:
                    assert len(token_result.logprobs) == 1

                    logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))

                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,

@@ -424,9 +386,7 @@ class MetaReferenceInferenceImpl(
            if stop_reason is None:
                stop_reason = StopReason.out_of_tokens

            message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)

            parsed_tool_calls = len(message.tool_calls) > 0
            if ipython and not parsed_tool_calls:


@@ -91,9 +91,7 @@ class LlamaModelParallelGenerator:
        self.group = ModelParallelProcessGroup(
            model_parallel_size,
            init_model_cb=partial(init_model_cb, self.config, self.model_id, self.llama_model),
        )
        self.group.start()
        return self


@@ -55,47 +55,33 @@ class ProcessingMessageName(str, Enum):
class ReadyRequest(BaseModel):
    type: Literal[ProcessingMessageName.ready_request] = ProcessingMessageName.ready_request


class ReadyResponse(BaseModel):
    type: Literal[ProcessingMessageName.ready_response] = ProcessingMessageName.ready_response


class EndSentinel(BaseModel):
    type: Literal[ProcessingMessageName.end_sentinel] = ProcessingMessageName.end_sentinel


class CancelSentinel(BaseModel):
    type: Literal[ProcessingMessageName.cancel_sentinel] = ProcessingMessageName.cancel_sentinel


class TaskRequest(BaseModel):
    type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
    task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]


class TaskResponse(BaseModel):
    type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
    result: TokenResult


class ExceptionResponse(BaseModel):
    type: Literal[ProcessingMessageName.exception_response] = ProcessingMessageName.exception_response
    error: str

@@ -189,9 +175,7 @@ def retrieve_requests(reply_socket_url: str):
                    group=get_model_parallel_group(),
                )
                if isinstance(updates[0], CancelSentinel):
                    log.info("quitting generation loop because request was cancelled")
                    break

            if mp_rank_0():

@@ -350,9 +334,7 @@ class ModelParallelProcessGroup:
    def run_inference(
        self,
        req: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent],
    ) -> Generator:
        assert not self.running, "inference already running"

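The models reformatted in the hunk above all follow the same tagged-message pattern: each pydantic class pins its type field to a single enum value via Literal, so worker messages can be told apart after serialization. A minimal, self-contained sketch of that pattern (simplified names, assuming pydantic v2; these are not the project's actual classes):

from enum import Enum
from typing import Literal, Union

from pydantic import BaseModel, TypeAdapter


class MessageName(str, Enum):
    ready_request = "ready_request"
    ready_response = "ready_response"


class ReadyRequest(BaseModel):
    # The Literal annotation plus the enum default acts as the discriminator tag.
    type: Literal[MessageName.ready_request] = MessageName.ready_request


class ReadyResponse(BaseModel):
    type: Literal[MessageName.ready_response] = MessageName.ready_response


Message = Union[ReadyRequest, ReadyResponse]

# Parsing an incoming payload picks whichever model's tag matches.
parsed = TypeAdapter(Message).validate_python({"type": MessageName.ready_request})
assert isinstance(parsed, ReadyRequest)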

@ -19,9 +19,7 @@ try:
log.info("Using efficient FP8 operators in FBGEMM.") log.info("Using efficient FP8 operators in FBGEMM.")
except ImportError: except ImportError:
log.error( log.error("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
"No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt."
)
raise raise
import torch import torch
@ -60,14 +58,8 @@ def ffn_swiglu(
num_tokens: Optional[Tensor] = None, num_tokens: Optional[Tensor] = None,
is_memory_bounded: bool = False, is_memory_bounded: bool = False,
) -> Tensor: ) -> Tensor:
if ( if isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights):
isinstance(w1, Fp8ScaledWeights) return ffn_swiglu_fp8_dynamic(x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded)
and isinstance(w3, Fp8ScaledWeights)
and isinstance(w2, Fp8ScaledWeights)
):
return ffn_swiglu_fp8_dynamic(
x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded
)
(B, T, D) = x.shape # noqa: N806 (B, T, D) = x.shape # noqa: N806
(HD_L, D_) = w1.shape # noqa: N806 (HD_L, D_) = w1.shape # noqa: N806
@ -146,12 +138,8 @@ def fc_fp8_dynamic(
Single w8a8 fc layer with dynamic row-wise scaling. Single w8a8 fc layer with dynamic row-wise scaling.
""" """
if isinstance(w, Fp8RowwiseWeights): if isinstance(w, Fp8RowwiseWeights):
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row( xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, activation_scale_ub)
x, num_tokens, activation_scale_ub y = torch.ops.fbgemm.f8f8bf16_rowwise(xq, w.weight, x_scale, w.scale, use_fast_accum=True)
)
y = torch.ops.fbgemm.f8f8bf16_rowwise(
xq, w.weight, x_scale, w.scale, use_fast_accum=True
)
del xq del xq
return y return y


@ -17,8 +17,7 @@ from torch import Tensor
@unittest.skipIf( @unittest.skipIf(
not torch.cuda.is_available() not torch.cuda.is_available() or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
"Skip when H100 is not available", "Skip when H100 is not available",
) )
class FP8Tests(unittest.TestCase): class FP8Tests(unittest.TestCase):


@ -57,9 +57,7 @@ class HadamardModule(torch.nn.Module):
return x return x
def add_hadamard_transform_for_spinquant( def add_hadamard_transform_for_spinquant(model: torch.nn.Module, prefix: str = "") -> None:
model: torch.nn.Module, prefix: str = ""
) -> None:
""" """
Adds a Hadamard transform to the last linear layer of each feedforward network (FFN) in the model. Adds a Hadamard transform to the last linear layer of each feedforward network (FFN) in the model.
This function recursively traverses the model's children and looks for layers that match the pattern This function recursively traverses the model's children and looks for layers that match the pattern
@ -81,12 +79,8 @@ def add_hadamard_transform_for_spinquant(
for module_name, module in model.named_children(): for module_name, module in model.named_children():
child_full_name = prefix + "." + module_name child_full_name = prefix + "." + module_name
if re.search(pattern_last_linear_ffn, child_full_name): if re.search(pattern_last_linear_ffn, child_full_name):
new_module = nn.Sequential( new_module = nn.Sequential(HadamardModule(group_size=module.in_features), module)
HadamardModule(group_size=module.in_features), module
)
del module del module
setattr(model, module_name, new_module) setattr(model, module_name, new_module)
else: else:
add_hadamard_transform_for_spinquant( add_hadamard_transform_for_spinquant(module, (prefix + "." if prefix else prefix) + module_name)
module, (prefix + "." if prefix else prefix) + module_name
)

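The SpinQuant hunk above uses a common recursive rewrite pattern: walk named_children, replace a matching child in place with a wrapped nn.Sequential, and otherwise recurse with an extended name prefix. A generic sketch of that pattern, with a placeholder regex and an nn.Identity stand-in where the real code inserts HadamardModule(group_size=module.in_features):

import re

from torch import nn


def wrap_matching_modules(model: nn.Module, pattern: str, prefix: str = "") -> None:
    for name, module in model.named_children():
        full_name = f"{prefix}.{name}" if prefix else name
        if re.search(pattern, full_name) and isinstance(module, nn.Linear):
            # Swap the child for a wrapper followed by the original module.
            setattr(model, name, nn.Sequential(nn.Identity(), module))
        else:
            wrap_matching_modules(module, pattern, full_name)


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
wrap_matching_modules(model, pattern=r"2$")  # wrap only the last linear layer
print(model)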

@ -63,12 +63,8 @@ def convert_to_fp8_quantized_model(
# Move weights to GPU with quantization # Move weights to GPU with quantization
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value: if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
log.info("Loading fp8 scales...") log.info("Loading fp8 scales...")
fp8_scales_path = os.path.join( fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt" assert os.path.isfile(fp8_scales_path), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
)
assert os.path.isfile(
fp8_scales_path
), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
fp8_scales = torch.load(fp8_scales_path, weights_only=True) fp8_scales = torch.load(fp8_scales_path, weights_only=True)
for block in model.layers: for block in model.layers:
@ -81,9 +77,7 @@ def convert_to_fp8_quantized_model(
param = getattr(block.feed_forward, key) param = getattr(block.feed_forward, key)
param.weight = load_fp8( param.weight = load_fp8(
param.weight, param.weight,
fp8_scales[ fp8_scales[f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"],
f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"
],
fp8_activation_scale_ub, fp8_activation_scale_ub,
) )
else: else:
@ -172,9 +166,7 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
if prefix + "zeros" not in state_dict: if prefix + "zeros" not in state_dict:
# Zero-point may not be saved in the state dict. In this case, we assume it's zero. # Zero-point may not be saved in the state dict. In this case, we assume it's zero.
assert prefix + "scales" in state_dict assert prefix + "scales" in state_dict
state_dict[prefix + "zeros"] = torch.zeros_like( state_dict[prefix + "zeros"] = torch.zeros_like(state_dict[prefix + "scales"])
state_dict[prefix + "scales"]
)
def forward(self, input_: torch.Tensor) -> torch.Tensor: def forward(self, input_: torch.Tensor) -> torch.Tensor:
module_out = super().forward(input_) module_out = super().forward(input_)
@ -229,9 +221,7 @@ class Int8WeightLinear(torch.nn.Linear):
bias: Whether to use bias. bias: Whether to use bias.
""" """
def __init__( def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None) -> None:
self, in_features: int, out_features: int, bias: bool = True, device=None
) -> None:
super().__init__(in_features, out_features, bias, device=device) super().__init__(in_features, out_features, bias, device=device)
self._register_load_state_dict_pre_hook(self.load_hook) self._register_load_state_dict_pre_hook(self.load_hook)
@ -295,9 +285,7 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
del module del module
setattr(model, module_name, quantized_module) setattr(model, module_name, quantized_module)
else: else:
_prepare_model_int4_weight_int8_dynamic_activation( _prepare_model_int4_weight_int8_dynamic_activation(module, group_size, lora_rank, lora_scale)
module, group_size, lora_rank, lora_scale
)
return model return model
@ -321,9 +309,7 @@ def convert_to_int4_quantized_model(
group_size = model_args.quantization_args.group_size group_size = model_args.quantization_args.group_size
if group_size is None: if group_size is None:
raise ValueError( raise ValueError("'group_size' cannot be None in 'quantization_args'. Please specify it.")
"'group_size' cannot be None in 'quantization_args'. Please specify it."
)
if model_args.lora_args is None: if model_args.lora_args is None:
# Certain quantized models (e.g., SpinQuant) may not have LoRA. # Certain quantized models (e.g., SpinQuant) may not have LoRA.
@ -333,8 +319,6 @@ def convert_to_int4_quantized_model(
lora_rank = model_args.lora_args.rank lora_rank = model_args.lora_args.rank
lora_scale = model_args.lora_args.scale lora_scale = model_args.lora_args.scale
_prepare_model_int4_weight_int8_dynamic_activation( _prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
model, group_size, lora_rank, lora_scale
)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
return model.to(device) return model.to(device)

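One detail worth calling out in the hunk above: when a quantized checkpoint carries no zero-point tensor, the loader assumes symmetric quantization and synthesizes zeros shaped like the saved scales. A tiny standalone illustration (the keys here are invented for the example, not the real state-dict prefix):

import torch

state_dict = {"w.scales": torch.rand(4, 1)}
prefix = "w."

if prefix + "zeros" not in state_dict:
    # Zero-points missing: assume symmetric quantization and default them to zeros.
    assert prefix + "scales" in state_dict
    state_dict[prefix + "zeros"] = torch.zeros_like(state_dict[prefix + "scales"])

print(state_dict[prefix + "zeros"].shape)  # same shape as the scales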

@ -76,9 +76,9 @@ def main(
checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len( assert model_parallel_size == len(checkpoints), (
checkpoints f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}" )
ckpt_path = checkpoints[get_model_parallel_rank()] ckpt_path = checkpoints[get_model_parallel_rank()]
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True) checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f: with open(Path(ckpt_dir) / "params.json", "r") as f:
@ -90,9 +90,9 @@ def main(
**params, **params,
) )
tokenizer = Tokenizer(model_path=tokenizer_path) tokenizer = Tokenizer(model_path=tokenizer_path)
assert ( assert model_args.vocab_size == tokenizer.n_words, (
model_args.vocab_size == tokenizer.n_words f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}" )
# load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype # load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor) torch.set_default_tensor_type(torch.BFloat16Tensor)
@ -106,9 +106,7 @@ def main(
torch.set_default_tensor_type(torch.cuda.HalfTensor) torch.set_default_tensor_type(torch.cuda.HalfTensor)
log.info(ckpt_path) log.info(ckpt_path)
assert ( assert quantized_ckpt_dir is not None, "QUantized checkpoint directory should not be None"
quantized_ckpt_dir is not None
), "QUantized checkpoint directory should not be None"
fp8_scales = {} fp8_scales = {}
for block in model.layers: for block in model.layers:
if isinstance(block, TransformerBlock): if isinstance(block, TransformerBlock):
@ -122,9 +120,7 @@ def main(
) )
with torch.inference_mode(): with torch.inference_mode():
block.feed_forward.w1.weight = Parameter(fp8_weight.weight) block.feed_forward.w1.weight = Parameter(fp8_weight.weight)
fp8_scales[ fp8_scales[f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"] = fp8_weight.scale
f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
] = fp8_weight.scale
fp8_weight = quantize_fp8( fp8_weight = quantize_fp8(
block.feed_forward.w3.weight, block.feed_forward.w3.weight,
@ -133,9 +129,7 @@ def main(
) )
with torch.inference_mode(): with torch.inference_mode():
block.feed_forward.w3.weight = Parameter(fp8_weight.weight) block.feed_forward.w3.weight = Parameter(fp8_weight.weight)
fp8_scales[ fp8_scales[f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"] = fp8_weight.scale
f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
] = fp8_weight.scale
fp8_weight = quantize_fp8( fp8_weight = quantize_fp8(
block.feed_forward.w2.weight, block.feed_forward.w2.weight,
@ -144,13 +138,9 @@ def main(
) )
with torch.inference_mode(): with torch.inference_mode():
block.feed_forward.w2.weight = Parameter(fp8_weight.weight) block.feed_forward.w2.weight = Parameter(fp8_weight.weight)
fp8_scales[ fp8_scales[f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"] = fp8_weight.scale
f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
] = fp8_weight.scale
fp8_scales_path = os.path.join( fp8_scales_path = os.path.join(quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
)
torch.save(fp8_scales, fp8_scales_path) torch.save(fp8_scales, fp8_scales_path)
ckpt_path = os.path.join( ckpt_path = os.path.join(


@ -10,7 +10,6 @@ from pydantic import BaseModel
class SentenceTransformersInferenceConfig(BaseModel): class SentenceTransformersInferenceConfig(BaseModel):
@classmethod @classmethod
def sample_run_config(cls) -> Dict[str, Any]: def sample_run_config(cls) -> Dict[str, Any]:
return {} return {}


@ -53,7 +53,5 @@ class VLLMConfig(BaseModel):
repos = [m.huggingface_repo for m in permitted_models] repos = [m.huggingface_repo for m in permitted_models]
if model not in (descriptors + repos): if model not in (descriptors + repos):
model_list = "\n\t".join(repos) model_list = "\n\t".join(repos)
raise ValueError( raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
)
return model return model


@ -176,13 +176,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
log.info("Sampling params: %s", sampling_params) log.info("Sampling params: %s", sampling_params)
request_id = _random_uuid() request_id = _random_uuid()
prompt = await chat_completion_request_to_prompt( prompt = await chat_completion_request_to_prompt(request, self.config.model, self.formatter)
request, self.config.model, self.formatter
)
vllm_sampling_params = self._sampling_params(request.sampling_params) vllm_sampling_params = self._sampling_params(request.sampling_params)
results_generator = self.engine.generate( results_generator = self.engine.generate(prompt, vllm_sampling_params, request_id)
prompt, vllm_sampling_params, request_id
)
if stream: if stream:
return self._stream_chat_completion(request, results_generator) return self._stream_chat_completion(request, results_generator)
else: else:
@ -230,12 +226,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
) )
stream = _generate_and_convert_to_openai_compat() stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response( async for chunk in process_chat_completion_stream_response(stream, self.formatter):
stream, self.formatter
):
yield chunk yield chunk
async def embeddings( async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
self, model_id: str, contents: List[InterleavedContent]
) -> EmbeddingsResponse:
raise NotImplementedError() raise NotImplementedError()


@ -47,6 +47,4 @@ async def validate_input_dataset_schema(
if dataset_type not in EXPECTED_DATASET_SCHEMA: if dataset_type not in EXPECTED_DATASET_SCHEMA:
raise ValueError(f"Dataset type {dataset_type} is not supported.") raise ValueError(f"Dataset type {dataset_type} is not supported.")
validate_dataset_schema( validate_dataset_schema(dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type])
dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type]
)


@ -42,9 +42,7 @@ class TorchtuneCheckpointer:
self._model_type = ModelType[model_type] self._model_type = ModelType[model_type]
self._output_dir = output_dir self._output_dir = output_dir
# get ckpt paths # get ckpt paths
self._checkpoint_path = Path.joinpath( self._checkpoint_path = Path.joinpath(self._checkpoint_dir, self._checkpoint_file)
self._checkpoint_dir, self._checkpoint_file
)
def load_checkpoint(self) -> Dict[str, Any]: def load_checkpoint(self) -> Dict[str, Any]:
""" """
@ -57,13 +55,9 @@ class TorchtuneCheckpointer:
llama3_vision_meta_to_tune, llama3_vision_meta_to_tune,
) )
state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune( state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(model_state_dict)
model_state_dict
)
else: else:
state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune( state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(model_state_dict)
model_state_dict
)
# llama3_2 has tied weights, so we need to remove the output.weight key # llama3_2 has tied weights, so we need to remove the output.weight key
if self._model_type == ModelType.LLAMA3_2: if self._model_type == ModelType.LLAMA3_2:
@ -82,10 +76,7 @@ class TorchtuneCheckpointer:
epoch: int, epoch: int,
adapter_only: bool = False, adapter_only: bool = False,
) -> str: ) -> str:
model_file_path = ( model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"
Path(self._output_dir)
/ f"{self._model_id}-{self._training_algorithm}-{epoch}"
)
model_file_path.mkdir(parents=True, exist_ok=True) model_file_path.mkdir(parents=True, exist_ok=True)
@ -116,22 +107,13 @@ class TorchtuneCheckpointer:
llama3_vision_tune_to_meta, llama3_vision_tune_to_meta,
) )
state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta( state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(model_state_dict)
model_state_dict
)
else: else:
# llama3_2 has tied weights, so we need to add the output.weight key # llama3_2 has tied weights, so we need to add the output.weight key
if ( if self._model_type == ModelType.LLAMA3_2 and "output.weight" not in model_state_dict:
self._model_type == ModelType.LLAMA3_2 model_state_dict["output.weight"] = model_state_dict["tok_embeddings.weight"]
and "output.weight" not in model_state_dict
):
model_state_dict["output.weight"] = model_state_dict[
"tok_embeddings.weight"
]
state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta( state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(model_state_dict)
model_state_dict
)
model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth") model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")


@ -15,18 +15,13 @@ from typing import Any, Mapping
from llama_stack.providers.utils.common.data_schema_validator import ColumnName from llama_stack.providers.utils.common.data_schema_validator import ColumnName
def llama_stack_instruct_to_torchtune_instruct( def llama_stack_instruct_to_torchtune_instruct(sample: Mapping[str, Any]) -> Mapping[str, Any]:
sample: Mapping[str, Any] assert ColumnName.chat_completion_input.value in sample and ColumnName.expected_answer.value in sample, (
) -> Mapping[str, Any]: "Invalid input row"
assert ( )
ColumnName.chat_completion_input.value in sample
and ColumnName.expected_answer.value in sample
), "Invalid input row"
input_messages = eval(str(sample[ColumnName.chat_completion_input.value])) input_messages = eval(str(sample[ColumnName.chat_completion_input.value]))
assert ( assert len(input_messages) == 1, "llama stack intruct dataset format only supports 1 user message"
len(input_messages) == 1
), "llama stack intruct dataset format only supports 1 user message"
input_message = input_messages[0] input_message = input_messages[0]
assert "content" in input_message, "content not found in input message" assert "content" in input_message, "content not found in input message"
@ -48,13 +43,9 @@ def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str
roles = [] roles = []
conversations = [] conversations = []
for message in dialog: for message in dialog:
assert ( assert "role" in message and "content" in message, "role and content must in message"
"role" in message and "content" in message
), "role and content must in message"
roles.append(message["role"]) roles.append(message["role"])
conversations.append( conversations.append({"from": role_map[message["role"]], "value": message["content"]})
{"from": role_map[message["role"]], "value": message["content"]}
)
assert roles[0] == "user", "first message must be from user" assert roles[0] == "user", "first message must be from user"
assert "assistant" in roles, "at least 1 message should be from assistant" assert "assistant" in roles, "at least 1 message should be from assistant"

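The conversion reformatted above maps Llama Stack chat rows (messages with "role" and "content") into torchtune-style turns with "from" and "value". A self-contained sketch of that mapping; the role_map values ("human"/"gpt") are an assumption for illustration, not read from the provider:

from typing import Any, Dict, List

role_map = {"user": "human", "assistant": "gpt"}

dialog: List[Dict[str, Any]] = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]

roles = []
conversations = []
for message in dialog:
    assert "role" in message and "content" in message, "role and content must be in message"
    roles.append(message["role"])
    conversations.append({"from": role_map[message["role"]], "value": message["content"]})

assert roles[0] == "user", "first message must be from user"
assert "assistant" in roles, "at least 1 message should be from assistant"
print(conversations)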

@ -61,8 +61,7 @@ class SFTDataset(Dataset):
if not ("tokens" in tokenized_dict and "mask" in tokenized_dict): if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
keys_str = ", ".join(tokenized_dict.keys()) keys_str = ", ".join(tokenized_dict.keys())
error_message = ( error_message = (
"model_transform returned the following keys: " f"model_transform returned the following keys: {keys_str}. Must return 'tokens' and 'mask' as keys."
f"{keys_str}. Must return 'tokens' and 'mask' as keys."
) )
raise ValueError(error_message) raise ValueError(error_message)


@ -119,9 +119,7 @@ class TorchtunePostTrainingImpl:
return ListPostTrainingJobsResponse(data=self.jobs_list) return ListPostTrainingJobsResponse(data=self.jobs_list)
@webmethod(route="/post-training/job/status") @webmethod(route="/post-training/job/status")
async def get_training_job_status( async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
self, job_uuid: str
) -> Optional[PostTrainingJobStatusResponse]:
if job_uuid in self.jobs_status: if job_uuid in self.jobs_status:
return self.jobs_status[job_uuid] return self.jobs_status[job_uuid]
return None return None
@ -131,12 +129,8 @@ class TorchtunePostTrainingImpl:
raise NotImplementedError("Job cancel is not implemented yet") raise NotImplementedError("Job cancel is not implemented yet")
@webmethod(route="/post-training/job/artifacts") @webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts( async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
self, job_uuid: str
) -> Optional[PostTrainingJobArtifactsResponse]:
if job_uuid in self.checkpoints_dict: if job_uuid in self.checkpoints_dict:
checkpoints = self.checkpoints_dict.get(job_uuid, []) checkpoints = self.checkpoints_dict.get(job_uuid, [])
return PostTrainingJobArtifactsResponse( return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=checkpoints)
job_uuid=job_uuid, checkpoints=checkpoints
)
return None return None


@ -94,9 +94,7 @@ class LoraFinetuningSingleDevice:
self.job_uuid = job_uuid self.job_uuid = job_uuid
self.training_config = training_config self.training_config = training_config
if not isinstance(algorithm_config, LoraFinetuningConfig): if not isinstance(algorithm_config, LoraFinetuningConfig):
raise ValueError( raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning")
"You need to speicifc LoraFinetuningConfig for LoRA finetuning"
)
self.algorithm_config = algorithm_config self.algorithm_config = algorithm_config
self._device = torchtune_utils.get_device(device="cuda") self._device = torchtune_utils.get_device(device="cuda")
self._dtype = training.get_dtype(training_config.dtype, device=self._device) self._dtype = training.get_dtype(training_config.dtype, device=self._device)
@ -105,10 +103,7 @@ class LoraFinetuningSingleDevice:
def model_checkpoint_dir(model) -> str: def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor())) checkpoint_dir = Path(model_local_dir(model.descriptor()))
paths = [ paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
Path(checkpoint_dir / f"consolidated.{ext}")
for ext in ["pth", "00.pth"]
]
if not any(p.exists() for p in paths): if not any(p.exists() for p in paths):
checkpoint_dir = checkpoint_dir / "original" checkpoint_dir = checkpoint_dir / "original"
@ -123,9 +118,7 @@ class LoraFinetuningSingleDevice:
else: else:
model = resolve_model(self.model_id) model = resolve_model(self.model_id)
if model is None: if model is None:
raise ValueError( raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
f"{self.model_id} not found. Your model id should be in the llama models SKU list"
)
self.checkpoint_dir = model_checkpoint_dir(model) self.checkpoint_dir = model_checkpoint_dir(model)
self._output_dir = str(DEFAULT_CHECKPOINT_DIR) self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
@ -196,9 +189,7 @@ class LoraFinetuningSingleDevice:
self._tokenizer = await self._setup_tokenizer() self._tokenizer = await self._setup_tokenizer()
log.info("Tokenizer is initialized.") log.info("Tokenizer is initialized.")
self._optimizer = await self._setup_optimizer( self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config)
optimizer_config=self.training_config.optimizer_config
)
log.info("Optimizer is initialized.") log.info("Optimizer is initialized.")
self._loss_fn = CEWithChunkedOutputLoss() self._loss_fn = CEWithChunkedOutputLoss()
@ -226,13 +217,8 @@ class LoraFinetuningSingleDevice:
# by the dataloader and the max_steps_per_epoch param set by the user and is used # by the dataloader and the max_steps_per_epoch param set by the user and is used
# for logging and tracking training state. This should be computed after the dataloader # for logging and tracking training state. This should be computed after the dataloader
# has been setup # has been setup
self._steps_per_epoch = ( self._steps_per_epoch = len(self._training_dataloader) // self._gradient_accumulation_steps
len(self._training_dataloader) // self._gradient_accumulation_steps if self.max_steps_per_epoch is not None and self.max_steps_per_epoch < self._steps_per_epoch:
)
if (
self.max_steps_per_epoch is not None
and self.max_steps_per_epoch < self._steps_per_epoch
):
self._steps_per_epoch = self.max_steps_per_epoch self._steps_per_epoch = self.max_steps_per_epoch
self.global_step = self.epochs_run * self._steps_per_epoch self.global_step = self.epochs_run * self._steps_per_epoch
@ -246,9 +232,7 @@ class LoraFinetuningSingleDevice:
log.info("Learning rate scheduler is initialized.") log.info("Learning rate scheduler is initialized.")
# Used to ignore labels for loss computation # Used to ignore labels for loss computation
self.ignore_labels_cache = torch.full( self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device)
(self._batch_size, 1), self._loss_fn.ignore_index, device=self._device
)
async def _setup_model( async def _setup_model(
self, self,
@ -282,13 +266,9 @@ class LoraFinetuningSingleDevice:
set_trainable_params(model, self.adapter_params) set_trainable_params(model, self.adapter_params)
if enable_activation_checkpointing: if enable_activation_checkpointing:
training.set_activation_checkpointing( training.set_activation_checkpointing(model, auto_wrap_policy={modules.TransformerSelfAttentionLayer})
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)
base_missing, base_unexpected = model.load_state_dict( base_missing, base_unexpected = model.load_state_dict(base_model_state_dict, strict=False)
base_model_state_dict, strict=False
)
# This is for any adapters that need to be initialized after base weights # This is for any adapters that need to be initialized after base weights
# have been loaded (e.g. DoRA). # have been loaded (e.g. DoRA).
@ -297,9 +277,7 @@ class LoraFinetuningSingleDevice:
if hasattr(m, "initialize_dora_magnitude"): if hasattr(m, "initialize_dora_magnitude"):
m.initialize_dora_magnitude() m.initialize_dora_magnitude()
if lora_weights_state_dict: if lora_weights_state_dict:
lora_missing, lora_unexpected = model.load_state_dict( lora_missing, lora_unexpected = model.load_state_dict(lora_weights_state_dict, strict=False)
lora_weights_state_dict, strict=False
)
else: else:
lora_missing, lora_unexpected = None, None lora_missing, lora_unexpected = None, None
validate_missing_and_unexpected_for_lora( validate_missing_and_unexpected_for_lora(
@ -313,14 +291,10 @@ class LoraFinetuningSingleDevice:
) )
# Validate model adapter params were loaded in with the expected dtype # Validate model adapter params were loaded in with the expected dtype
training.validate_expected_param_dtype( training.validate_expected_param_dtype(self.adapter_params.items(), dtype=self._dtype)
self.adapter_params.items(), dtype=self._dtype
)
# activation offloading # activation offloading
self.activations_handling_ctx = training.get_act_offloading_ctx_manager( self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading)
model, enable_activation_offloading
)
memory_stats = training.get_memory_stats(device=self._device) memory_stats = training.get_memory_stats(device=self._device)
training.log_memory_stats(memory_stats) training.log_memory_stats(memory_stats)
@ -456,9 +430,7 @@ class LoraFinetuningSingleDevice:
# Shift labels to compute loss # Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :] # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
# But this way we dont need to slice the logits. We just add an ignore index to labels. # But this way we dont need to slice the logits. We just add an ignore index to labels.
labels = torch.hstack( labels = torch.hstack((labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]))
(labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
)
if not isinstance(logits, list): if not isinstance(logits, list):
labels = labels.reshape(-1) labels = labels.reshape(-1)
logits = logits.reshape(-1, logits.size(-1)) logits = logits.reshape(-1, logits.size(-1))
@ -487,9 +459,7 @@ class LoraFinetuningSingleDevice:
for curr_epoch in range(self.epochs_run, self.total_epochs): for curr_epoch in range(self.epochs_run, self.total_epochs):
# Update the sampler to ensure data is correctly shuffled across epochs # Update the sampler to ensure data is correctly shuffled across epochs
# in case shuffle is True # in case shuffle is True
metric_logger = DiskLogger( metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}")
log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
)
self._training_sampler.set_epoch(curr_epoch) self._training_sampler.set_epoch(curr_epoch)
loss_to_log = 0.0 loss_to_log = 0.0
@ -497,8 +467,7 @@ class LoraFinetuningSingleDevice:
for idx, batch in enumerate(self._training_dataloader): for idx, batch in enumerate(self._training_dataloader):
if ( if (
self.max_steps_per_epoch is not None self.max_steps_per_epoch is not None
and (idx // self._gradient_accumulation_steps) and (idx // self._gradient_accumulation_steps) == self.max_steps_per_epoch
== self.max_steps_per_epoch
): ):
break break
@ -506,9 +475,7 @@ class LoraFinetuningSingleDevice:
# Calculate the number of unmasked tokens in the current batch # Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step # and increment the total number of tokens seen in the step
current_num_tokens = ( current_num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum()
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens num_tokens += current_num_tokens
# Loss is normalized by default so we multiply by the number of tokens # Loss is normalized by default so we multiply by the number of tokens
@ -533,9 +500,7 @@ class LoraFinetuningSingleDevice:
loss_to_log = running_loss.item() / num_tokens loss_to_log = running_loss.item() / num_tokens
pbar.update(1) pbar.update(1)
pbar.set_description( pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
)
time_per_step = time.perf_counter() - t0 time_per_step = time.perf_counter() - t0
log_dict = { log_dict = {

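The training-loop hunk above relies on the label-shift trick described in its comment: rather than slicing logits[..., :-1, :], append an ignore-index column to the shifted labels so logits and labels keep the same sequence length. A standalone sketch (not the trainer's code) showing the two formulations give the same loss:

import torch
import torch.nn.functional as F

ignore_index = -100
batch, seq_len, vocab = 2, 5, 11

logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# Shift labels left by one and pad the end with ignore_index.
ignore_col = torch.full((batch, 1), ignore_index)
shifted = torch.hstack((labels[..., 1:], ignore_col))

loss = F.cross_entropy(
    logits.reshape(-1, vocab),
    shifted.reshape(-1),
    ignore_index=ignore_index,
)

# Reference: the "slice the logits" formulation from the comment.
ref = F.cross_entropy(
    logits[:, :-1, :].reshape(-1, vocab),
    labels[:, 1:].reshape(-1),
    ignore_index=ignore_index,
)
assert torch.allclose(loss, ref)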

@ -67,10 +67,6 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
violation = SafetyViolation( violation = SafetyViolation(
violation_level=(ViolationLevel.ERROR), violation_level=(ViolationLevel.ERROR),
user_message="Sorry, I found security concerns in the code.", user_message="Sorry, I found security concerns in the code.",
metadata={ metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])},
"violation_type": ",".join(
[issue.pattern_id for issue in result.issues_found]
)
},
) )
return RunShieldResponse(violation=violation) return RunShieldResponse(violation=violation)


@ -10,9 +10,7 @@ from .config import LlamaGuardConfig
async def get_provider_impl(config: LlamaGuardConfig, deps): async def get_provider_impl(config: LlamaGuardConfig, deps):
from .llama_guard import LlamaGuardSafetyImpl from .llama_guard import LlamaGuardSafetyImpl
assert isinstance( assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"
config, LlamaGuardConfig
), f"Unexpected config type: {type(config)}"
impl = LlamaGuardSafetyImpl(config, deps) impl = LlamaGuardSafetyImpl(config, deps)
await impl.initialize() await impl.initialize()


@ -102,8 +102,7 @@ LLAMA_GUARD_MODEL_IDS = {
} }
MODEL_TO_SAFETY_CATEGORIES_MAP = { MODEL_TO_SAFETY_CATEGORIES_MAP = {
"meta-llama/Llama-Guard-3-8B": DEFAULT_LG_V3_SAFETY_CATEGORIES "meta-llama/Llama-Guard-3-8B": DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE],
+ [CAT_CODE_INTERPRETER_ABUSE],
"meta-llama/Llama-Guard-3-1B": DEFAULT_LG_V3_SAFETY_CATEGORIES, "meta-llama/Llama-Guard-3-1B": DEFAULT_LG_V3_SAFETY_CATEGORIES,
"meta-llama/Llama-Guard-3-11B-Vision": DEFAULT_LG_V3_SAFETY_CATEGORIES, "meta-llama/Llama-Guard-3-11B-Vision": DEFAULT_LG_V3_SAFETY_CATEGORIES,
} }
@ -133,9 +132,7 @@ Provide your safety assessment for ONLY THE LAST $agent_type message in the abov
- If unsafe, a second line must include a comma-separated list of violated categories.""" - If unsafe, a second line must include a comma-separated list of violated categories."""
PROMPT_TEMPLATE = Template( PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}")
f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}"
)
class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
@ -233,9 +230,7 @@ class LlamaGuardShield:
if messages[0].role != Role.user.value: if messages[0].role != Role.user.value:
raise ValueError("Messages must start with user") raise ValueError("Messages must start with user")
if len(messages) >= 2 and ( if len(messages) >= 2 and (messages[0].role == Role.user.value and messages[1].role == Role.user.value):
messages[0].role == Role.user.value and messages[1].role == Role.user.value
):
messages = messages[1:] messages = messages[1:]
for i in range(1, len(messages)): for i in range(1, len(messages)):
@ -263,10 +258,7 @@ class LlamaGuardShield:
stream=True, stream=True,
): ):
event = chunk.event event = chunk.event
if ( if event.event_type == ChatCompletionResponseEventType.progress and event.delta.type == "text":
event.event_type == ChatCompletionResponseEventType.progress
and event.delta.type == "text"
):
content += event.delta.text content += event.delta.text
content = content.strip() content = content.strip()
@ -313,10 +305,7 @@ class LlamaGuardShield:
categories = self.get_safety_categories() categories = self.get_safety_categories()
categories_str = "\n".join(categories) categories_str = "\n".join(categories)
conversations_str = "\n\n".join( conversations_str = "\n\n".join(
[ [f"{m.role.capitalize()}: {interleaved_content_as_str(m.content)}" for m in messages]
f"{m.role.capitalize()}: {interleaved_content_as_str(m.content)}"
for m in messages
]
) )
return PROMPT_TEMPLATE.substitute( return PROMPT_TEMPLATE.substitute(
agent_type=messages[-1].role.capitalize(), agent_type=messages[-1].role.capitalize(),

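The Llama Guard hunk above assembles its shield prompt with string.Template and fills it per request (agent type, category list, rendered conversation). A small sketch of that assembly; the placeholder names and surrounding text here are illustrative, not the shield's exact prompt:

from string import Template

PROMPT = Template(
    "Provide your safety assessment for ONLY THE LAST $agent_type message.\n"
    "<BEGIN UNSAFE CONTENT CATEGORIES>\n$categories\n<END UNSAFE CONTENT CATEGORIES>\n\n"
    "$conversations"
)

categories = ["S1: Violent Crimes.", "S2: Non-Violent Crimes."]
messages = [("User", "How do I bake a cake?")]

prompt = PROMPT.substitute(
    agent_type="User",
    categories="\n".join(categories),
    conversations="\n\n".join(f"{role}: {content}" for role, content in messages),
)
print(prompt)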

@ -46,9 +46,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
async def register_shield(self, shield: Shield) -> None: async def register_shield(self, shield: Shield) -> None:
if shield.provider_resource_id != PROMPT_GUARD_MODEL: if shield.provider_resource_id != PROMPT_GUARD_MODEL:
raise ValueError( raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ")
f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. "
)
async def run_shield( async def run_shield(
self, self,
@ -71,9 +69,7 @@ class PromptGuardShield:
threshold: float = 0.9, threshold: float = 0.9,
temperature: float = 1.0, temperature: float = 1.0,
): ):
assert ( assert model_dir is not None, "Must provide a model directory for prompt injection shield"
model_dir is not None
), "Must provide a model directory for prompt injection shield"
if temperature <= 0: if temperature <= 0:
raise ValueError("Temperature must be greater than 0") raise ValueError("Temperature must be greater than 0")
@ -85,9 +81,7 @@ class PromptGuardShield:
# load model and tokenizer # load model and tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_dir) self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForSequenceClassification.from_pretrained( self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)
model_dir, device_map=self.device
)
async def run(self, messages: List[Message]) -> RunShieldResponse: async def run(self, messages: List[Message]) -> RunShieldResponse:
message = messages[-1] message = messages[-1]
@ -117,10 +111,7 @@ class PromptGuardShield:
"violation_type": f"prompt_injection:embedded={score_embedded},malicious={score_malicious}", "violation_type": f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
}, },
) )
elif ( elif self.config.guard_type == PromptGuardType.jailbreak.value and score_malicious > self.threshold:
self.config.guard_type == PromptGuardType.jailbreak.value
and score_malicious > self.threshold
):
violation = SafetyViolation( violation = SafetyViolation(
violation_level=ViolationLevel.ERROR, violation_level=ViolationLevel.ERROR,
violation_type=f"prompt_injection:malicious={score_malicious}", violation_type=f"prompt_injection:malicious={score_malicious}",

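The Prompt Guard hunk above thresholds classifier scores; in the jailbreak configuration a message is flagged once the "malicious" probability exceeds the configured threshold. A rough, self-contained sketch with made-up logits and an assumed class ordering (the real shield gets its logits from the loaded HuggingFace classifier):

import torch

threshold = 0.9
logits = torch.tensor([[0.2, 1.5, 4.0]])  # assumed order: [benign, embedded_injection, malicious]

probs = torch.softmax(logits, dim=-1)[0]
score_malicious = probs[2].item()

if score_malicious > threshold:
    print(f"violation: prompt_injection:malicious={score_malicious:.3f}")
else:
    print("no violation")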

@ -54,15 +54,11 @@ class BasicScoringImpl(
async def list_scoring_functions(self) -> List[ScoringFn]: async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = [ scoring_fn_defs_list = [
fn_def fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
for impl in self.scoring_fn_id_impls.values()
for fn_def in impl.get_supported_scoring_fn_defs()
] ]
for f in scoring_fn_defs_list: for f in scoring_fn_defs_list:
assert f.identifier.startswith( assert f.identifier.startswith("basic"), "All basic scoring fn must have identifier prefixed with 'basic'! "
"basic"
), "All basic scoring fn must have identifier prefixed with 'basic'! "
return scoring_fn_defs_list return scoring_fn_defs_list
@ -76,9 +72,7 @@ class BasicScoringImpl(
save_results_dataset: bool = False, save_results_dataset: bool = False,
) -> ScoreBatchResponse: ) -> ScoreBatchResponse:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema( validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
all_rows = await self.datasetio_api.get_rows_paginated( all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id, dataset_id=dataset_id,
@ -108,12 +102,8 @@ class BasicScoringImpl(
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.") raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id] scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
scoring_fn_params = scoring_functions.get(scoring_fn_id, None) scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
score_results = await scoring_fn.score( score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
input_rows, scoring_fn_id, scoring_fn_params agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
)
agg_results = await scoring_fn.aggregate(
score_results, scoring_fn_id, scoring_fn_params
)
res[scoring_fn_id] = ScoringResult( res[scoring_fn_id] = ScoringResult(
score_rows=score_results, score_rows=score_results,
aggregated_results=agg_results, aggregated_results=agg_results,


@ -32,9 +32,7 @@ class EqualityScoringFn(RegisteredBaseScoringFn):
scoring_params: Optional[ScoringFnParams] = None, scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow: ) -> ScoringResultRow:
assert "expected_answer" in input_row, "Expected answer not found in input row." assert "expected_answer" in input_row, "Expected answer not found in input row."
assert ( assert "generated_answer" in input_row, "Generated answer not found in input row."
"generated_answer" in input_row
), "Generated answer not found in input row."
expected_answer = input_row["expected_answer"] expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"] generated_answer = input_row["generated_answer"]


@ -18,7 +18,5 @@ equality = ScoringFn(
provider_id="basic", provider_id="basic",
provider_resource_id="equality", provider_resource_id="equality",
return_type=NumberType(), return_type=NumberType(),
params=BasicScoringFnParams( params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
aggregation_functions=[AggregationFunctionType.accuracy]
),
) )


@ -55,9 +55,7 @@ MULTILINGUAL_ANSWER_REGEXES = [
r"Àṣàyàn\s*:", r"Àṣàyàn\s*:",
] ]
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = ( MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[]|[]|[]|[])"
r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[]|[]|[]|[])"
)
regex_parser_multiple_choice_answer = ScoringFn( regex_parser_multiple_choice_answer = ScoringFn(
identifier="basic::regex_parser_multiple_choice_answer", identifier="basic::regex_parser_multiple_choice_answer",
@ -66,10 +64,7 @@ regex_parser_multiple_choice_answer = ScoringFn(
provider_id="basic", provider_id="basic",
provider_resource_id="regex-parser-multiple-choice-answer", provider_resource_id="regex-parser-multiple-choice-answer",
params=RegexParserScoringFnParams( params=RegexParserScoringFnParams(
parsing_regexes=[ parsing_regexes=[MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x) for x in MULTILINGUAL_ANSWER_REGEXES],
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
for x in MULTILINGUAL_ANSWER_REGEXES
],
aggregation_functions=[AggregationFunctionType.accuracy], aggregation_functions=[AggregationFunctionType.accuracy],
), ),
) )

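The scoring-function definition above builds its parsing_regexes by formatting each answer-prefix regex into the multiple-choice pattern template and then searching the generated text. A simplified, runnable sketch with ASCII-only option letters and an assumed "Answer:" prefix, not the full multilingual list:

import re

ANSWER_PATTERN_TEMPLATE = r"(?i){}\s*([A-D])"

pattern = ANSWER_PATTERN_TEMPLATE.format(r"Answer\s*:")
match = re.search(pattern, "The final Answer: B, because ...")
assert match is not None and match.group(1) == "B"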

@ -18,7 +18,5 @@ subset_of = ScoringFn(
return_type=NumberType(), return_type=NumberType(),
provider_id="basic", provider_id="basic",
provider_resource_id="subset-of", provider_resource_id="subset-of",
params=BasicScoringFnParams( params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
aggregation_functions=[AggregationFunctionType.accuracy]
),
) )


@ -33,17 +33,14 @@ class RegexParserScoringFn(RegisteredBaseScoringFn):
scoring_fn_identifier: Optional[str] = None, scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None, scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow: ) -> ScoringResultRow:
assert ( assert scoring_fn_identifier is not None, "Scoring function identifier not found."
scoring_fn_identifier is not None
), "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier] fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
if scoring_params is not None: if scoring_params is not None:
fn_def.params = scoring_params fn_def.params = scoring_params
assert ( assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
fn_def.params is not None f"RegexParserScoringFnParams not found for {fn_def}."
and fn_def.params.type == ScoringFnParamsType.regex_parser.value )
), f"RegexParserScoringFnParams not found for {fn_def}."
expected_answer = input_row["expected_answer"] expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"] generated_answer = input_row["generated_answer"]


@ -124,12 +124,10 @@ class BraintrustScoringImpl(
self.datasets_api = datasets_api self.datasets_api = datasets_api
self.braintrust_evaluators = { self.braintrust_evaluators = {
entry.identifier: entry.evaluator entry.identifier: entry.evaluator for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
} }
self.supported_fn_defs_registry = { self.supported_fn_defs_registry = {
entry.identifier: entry.fn_def entry.identifier: entry.fn_def for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
} }
async def initialize(self) -> None: ... async def initialize(self) -> None: ...
@ -139,16 +137,14 @@ class BraintrustScoringImpl(
async def list_scoring_functions(self) -> List[ScoringFn]: async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()] scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
for f in scoring_fn_defs_list: for f in scoring_fn_defs_list:
assert f.identifier.startswith( assert f.identifier.startswith("braintrust"), (
"braintrust" "All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
), "All braintrust scoring fn must have identifier prefixed with 'braintrust'! " )
return scoring_fn_defs_list return scoring_fn_defs_list
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: async def register_scoring_function(self, scoring_fn: ScoringFn) -> None:
raise NotImplementedError( raise NotImplementedError("Registering scoring function not allowed for braintrust provider")
"Registering scoring function not allowed for braintrust provider"
)
async def set_api_key(self) -> None: async def set_api_key(self) -> None:
# api key is in the request headers # api key is in the request headers
@ -171,17 +167,13 @@ class BraintrustScoringImpl(
await self.set_api_key() await self.set_api_key()
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema( validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
all_rows = await self.datasetio_api.get_rows_paginated( all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=-1, rows_in_page=-1,
) )
res = await self.score( res = await self.score(input_rows=all_rows.rows, scoring_functions=scoring_functions)
input_rows=all_rows.rows, scoring_functions=scoring_functions
)
if save_results_dataset: if save_results_dataset:
# TODO: persist and register dataset on to server for reading # TODO: persist and register dataset on to server for reading
# self.datasets_api.register_dataset() # self.datasets_api.register_dataset()
@ -222,13 +214,8 @@ class BraintrustScoringImpl(
if scoring_fn_id not in self.supported_fn_defs_registry: if scoring_fn_id not in self.supported_fn_defs_registry:
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.") raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
score_results = [ score_results = [await self.score_row(input_row, scoring_fn_id) for input_row in input_rows]
await self.score_row(input_row, scoring_fn_id) aggregation_functions = self.supported_fn_defs_registry[scoring_fn_id].params.aggregation_functions
for input_row in input_rows
]
aggregation_functions = self.supported_fn_defs_registry[
scoring_fn_id
].params.aggregation_functions
# override scoring_fn params if provided # override scoring_fn params if provided
if scoring_functions[scoring_fn_id] is not None: if scoring_functions[scoring_fn_id] is not None:


@ -21,7 +21,5 @@ answer_correctness_fn_def = ScoringFn(
provider_id="braintrust", provider_id="braintrust",
provider_resource_id="answer-correctness", provider_resource_id="answer-correctness",
return_type=NumberType(), return_type=NumberType(),
params=BasicScoringFnParams( params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
aggregation_functions=[AggregationFunctionType.average]
),
) )


@ -20,7 +20,5 @@ answer_relevancy_fn_def = ScoringFn(
provider_id="braintrust", provider_id="braintrust",
provider_resource_id="answer-relevancy", provider_resource_id="answer-relevancy",
return_type=NumberType(), return_type=NumberType(),
params=BasicScoringFnParams( params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
aggregation_functions=[AggregationFunctionType.average]
),
) )


@ -20,7 +20,5 @@ answer_similarity_fn_def = ScoringFn(
provider_id="braintrust", provider_id="braintrust",
provider_resource_id="answer-similarity", provider_resource_id="answer-similarity",
return_type=NumberType(), return_type=NumberType(),
params=BasicScoringFnParams( params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
aggregation_functions=[AggregationFunctionType.average]
),
) )
