Fix precommit check after moving to ruff (#927)

The lint check on the main branch is failing. This fixes the lint check after we
moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We
need to move to a `ruff.toml` file, as well as fix and ignore some
additional checks.
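
For reference, a minimal sketch of how to run the same checks locally. This
assumes ruff is invoked through the repository's pre-commit hooks (which is
what the failing lint job exercises); the exact hook setup lives in
`.pre-commit-config.yaml`.

```bash
# Install the tooling (uv is optional; plain pip works as well)
uv pip install pre-commit ruff

# Run every configured hook against the whole tree, as CI does
pre-commit run --all-files

# Or invoke ruff directly; it picks up ruff.toml from the repository root
ruff check .
```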

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
Author: Yuan Tang (committed by GitHub)
Date:   2025-02-02 09:46:45 -05:00
commit 34ab7a3b6c (parent 4773092dd1)
217 changed files with 981 additions and 2681 deletions


@ -1,7 +1,8 @@
[flake8]
# Suggested config from pytorch that we can adapt
select = B,C,E,F,N,P,T4,W,B9,TOR0,TOR1,TOR2
max-line-length = 120
lint.select = ["B", "C", "E" , "F" , "N", "W", "B9"]
line-length = 120
# C408 ignored because we like the dict keyword argument syntax
# E501 is not flexible enough, we're using B950 instead
# N812 ignored because import torch.nn.functional as F is PyTorch convention
@ -9,23 +10,28 @@ max-line-length = 120
# E731 allow usage of assigning lambda expressions
# E701 let black auto-format statements on one line
# E704 let black auto-format statements on one line
ignore =
E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731,E701,E704
lint.ignore = [
"E203", "E305", "E402", "E501", "E721", "E741", "F405", "F821", "F841",
"C408", "E302", "W291", "E303", "N812", "N817", "E731", "E701",
# These are the additional ones we started ignoring after moving to ruff. We should look into each one of them later.
"C901", "C405", "C414", "N803", "N999", "C403", "C416", "B028", "C419", "C401", "B023",
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
# to line this up with executable bit
EXE001,
"EXE001",
# random naming hints don't need
N802,
"N802",
# these ignores are from flake8-bugbear; please fix!
B007,B008,B950
optional-ascii-coding = True
exclude =
./.git,
./docs/*,
./build,
./scripts,
./venv,
*.pyi,
.pre-commit-config.yaml,
*.md,
.flake8
"B007", "B008"
]
exclude = [
"./.git",
"./docs/*",
"./build",
"./scripts",
"./venv",
"*.pyi",
".pre-commit-config.yaml",
"*.md",
".flake8"
]
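
Pieced together from the added lines above, the new `ruff.toml` ends up looking
roughly like the sketch below (comments paraphrased; the committed file is
authoritative):

```toml
lint.select = ["B", "C", "E", "F", "N", "W", "B9"]
line-length = 120

lint.ignore = [
    # Carried over from the old .flake8 ignore list
    "E203", "E305", "E402", "E501", "E721", "E741", "F405", "F821", "F841",
    "C408", "E302", "W291", "E303", "N812", "N817", "E731", "E701",
    # Additional rules ignored after the move to ruff; to be revisited later
    "C901", "C405", "C414", "N803", "N999", "C403", "C416", "B028", "C419", "C401", "B023",
    # shebang handling and naming hints
    "EXE001",
    "N802",
    # these ignores are from flake8-bugbear
    "B007", "B008"
]

exclude = [
    "./.git",
    "./docs/*",
    "./build",
    "./scripts",
    "./venv",
    "*.pyi",
    ".pre-commit-config.yaml",
    "*.md",
    ".flake8"
]
```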


@ -77,7 +77,7 @@ agent_config = AgentConfig(
instructions="You are a helpful assistant",
# Enable both RAG and tool usage
toolgroups=[
{"name": "builtin::rag", "args": {"vector_db_ids": ["my_docs"]}}.
{"name": "builtin::rag", "args": {"vector_db_ids": ["my_docs"]}},
"builtin::code_interpreter",
],
# Configure safety
@ -86,13 +86,9 @@ agent_config = AgentConfig(
# Control the inference loop
max_infer_iters=5,
sampling_params={
"strategy": {
"type": "top_p",
"temperature": 0.7,
"top_p": 0.95
},
"max_tokens": 2048
}
"strategy": {"type": "top_p", "temperature": 0.7, "top_p": 0.95},
"max_tokens": 2048,
},
)
agent = Agent(client, agent_config)
@ -101,11 +97,13 @@ session_id = agent.create_session("monitored_session")
# Stream the agent's execution steps
response = agent.create_turn(
messages=[{"role": "user", "content": "Analyze this code and run it"}],
attachments=[{
"content": "https://raw.githubusercontent.com/example/code.py",
"mime_type": "text/plain"
}],
session_id=session_id
attachments=[
{
"content": "https://raw.githubusercontent.com/example/code.py",
"mime_type": "text/plain",
}
],
session_id=session_id,
)
# Monitor each step of execution


@ -15,6 +15,7 @@ This first example walks you through how to evaluate a model candidate served by
```python
import datasets
ds = datasets.load_dataset(path="llamastack/mmmu", name="Agriculture", split="dev")
ds = ds.select_columns(["chat_completion_input", "input_query", "expected_answer"])
eval_rows = ds.to_pandas().to_dict(orient="records")
@ -43,7 +44,7 @@ system_message = {
client.eval_tasks.register(
eval_task_id="meta-reference::mmmu",
dataset_id=f"mmmu-{subset}-{split}",
scoring_functions=["basic::regex_parser_multiple_choice_answer"]
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
)
response = client.eval.evaluate_rows(
@ -62,9 +63,9 @@ response = client.eval.evaluate_rows(
"max_tokens": 4096,
"repeat_penalty": 1.0,
},
"system_message": system_message
}
}
"system_message": system_message,
},
},
)
```
@ -88,7 +89,7 @@ _ = client.datasets.register(
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "chat_completion_input"},
}
},
)
eval_rows = client.datasetio.get_rows_paginated(
@ -101,7 +102,7 @@ eval_rows = client.datasetio.get_rows_paginated(
client.eval_tasks.register(
eval_task_id="meta-reference::simpleqa",
dataset_id=simpleqa_dataset_id,
scoring_functions=["llm-as-judge::405b-simpleqa"]
scoring_functions=["llm-as-judge::405b-simpleqa"],
)
response = client.eval.evaluate_rows(
@ -120,8 +121,8 @@ response = client.eval.evaluate_rows(
"max_tokens": 4096,
"repeat_penalty": 1.0,
},
}
}
},
},
)
```
@ -144,14 +145,14 @@ agent_config = {
{
"type": "brave_search",
"engine": "tavily",
"api_key": userdata.get("TAVILY_SEARCH_API_KEY")
"api_key": userdata.get("TAVILY_SEARCH_API_KEY"),
}
],
"tool_choice": "auto",
"tool_prompt_format": "json",
"input_shields": [],
"output_shields": [],
"enable_session_persistence": False
"enable_session_persistence": False,
}
response = client.eval.evaluate_rows(
@ -163,7 +164,7 @@ response = client.eval.evaluate_rows(
"eval_candidate": {
"type": "agent",
"config": agent_config,
}
}
},
},
)
```


@ -13,7 +13,7 @@ Here's how to set up basic evaluation:
response = client.eval_tasks.register(
eval_task_id="my_eval",
dataset_id="my_dataset",
scoring_functions=["accuracy", "relevance"]
scoring_functions=["accuracy", "relevance"],
)
# Run evaluation
@ -21,16 +21,10 @@ job = client.eval.run_eval(
task_id="my_eval",
task_config={
"type": "app",
"eval_candidate": {
"type": "agent",
"config": agent_config
}
}
"eval_candidate": {"type": "agent", "config": agent_config},
},
)
# Get results
result = client.eval.job_result(
task_id="my_eval",
job_id=job.job_id
)
result = client.eval.job_result(task_id="my_eval", job_id=job.job_id)
```


@ -34,15 +34,16 @@ chunks = [
{
"document_id": "doc1",
"content": "Your document text here",
"mime_type": "text/plain"
"mime_type": "text/plain",
},
...
...,
]
client.vector_io.insert(vector_db_id, chunks)
# You can then query for these chunks
chunks_response = client.vector_io.query(vector_db_id, query="What do you know about...")
chunks_response = client.vector_io.query(
vector_db_id, query="What do you know about..."
)
```
### Using the RAG Tool
@ -81,7 +82,6 @@ results = client.tool_runtime.rag_tool.query(
One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:
```python
# Configure agent with memory
agent_config = AgentConfig(
model="Llama3.2-3B-Instruct",
@ -91,9 +91,9 @@ agent_config = AgentConfig(
"name": "builtin::rag",
"args": {
"vector_db_ids": [vector_db_id],
}
},
}
]
],
)
agent = Agent(client, agent_config)
@ -101,25 +101,21 @@ session_id = agent.create_session("rag_session")
# Initial document ingestion
response = agent.create_turn(
messages=[{
"role": "user",
"content": "I am providing some documents for reference."
}],
messages=[
{"role": "user", "content": "I am providing some documents for reference."}
],
documents=[
dict(
content="https://raw.githubusercontent.com/example/doc.rst",
mime_type="text/plain"
mime_type="text/plain",
)
],
session_id=session_id
session_id=session_id,
)
# Query with RAG
response = agent.create_turn(
messages=[{
"role": "user",
"content": "What are the key topics in the documents?"
}],
session_id=session_id
messages=[{"role": "user", "content": "What are the key topics in the documents?"}],
session_id=session_id,
)
```


@ -5,15 +5,11 @@ Safety is a critical component of any AI application. Llama Stack provides a Shi
```python
# Register a safety shield
shield_id = "content_safety"
client.shields.register(
shield_id=shield_id,
provider_shield_id="llama-guard-basic"
)
client.shields.register(shield_id=shield_id, provider_shield_id="llama-guard-basic")
# Run content through shield
response = client.safety.run_shield(
shield_id=shield_id,
messages=[{"role": "user", "content": "User message here"}]
shield_id=shield_id, messages=[{"role": "user", "content": "User message here"}]
)
if response.violation:


@ -8,24 +8,16 @@ The telemetry system supports three main types of events:
- **Unstructured Log Events**: Free-form log messages with severity levels
```python
unstructured_log_event = UnstructuredLogEvent(
message="This is a log message",
severity=LogSeverity.INFO
message="This is a log message", severity=LogSeverity.INFO
)
```
- **Metric Events**: Numerical measurements with units
```python
metric_event = MetricEvent(
metric="my_metric",
value=10,
unit="count"
)
metric_event = MetricEvent(metric="my_metric", value=10, unit="count")
```
- **Structured Log Events**: System events like span start/end. Extensible to add more structured log types.
```python
structured_log_event = SpanStartPayload(
name="my_span",
parent_span_id="parent_span_id"
)
structured_log_event = SpanStartPayload(name="my_span", parent_span_id="parent_span_id")
```
### Spans and Traces


@ -35,7 +35,7 @@ Example client SDK call to register a "websearch" toolgroup that is provided by
client.toolgroups.register(
toolgroup_id="builtin::websearch",
provider_id="brave-search",
args={"max_results": 5}
args={"max_results": 5},
)
```
@ -50,8 +50,7 @@ The Code Interpreter allows execution of Python code within a controlled environ
```python
# Register Code Interpreter tool group
client.toolgroups.register(
toolgroup_id="builtin::code_interpreter",
provider_id="code_interpreter"
toolgroup_id="builtin::code_interpreter", provider_id="code_interpreter"
)
```
@ -68,16 +67,14 @@ The WolframAlpha tool provides access to computational knowledge through the Wol
```python
# Register WolframAlpha tool group
client.toolgroups.register(
toolgroup_id="builtin::wolfram_alpha",
provider_id="wolfram-alpha"
toolgroup_id="builtin::wolfram_alpha", provider_id="wolfram-alpha"
)
```
Example usage:
```python
result = client.tool_runtime.invoke_tool(
tool_name="wolfram_alpha",
args={"query": "solve x^2 + 2x + 1 = 0"}
tool_name="wolfram_alpha", args={"query": "solve x^2 + 2x + 1 = 0"}
)
```
@ -90,10 +87,7 @@ The Memory tool enables retrieval of context from various types of memory banks
client.toolgroups.register(
toolgroup_id="builtin::memory",
provider_id="memory",
args={
"max_chunks": 5,
"max_tokens_in_context": 4096
}
args={"max_chunks": 5, "max_tokens_in_context": 4096},
)
```
@ -136,9 +130,7 @@ config = AgentConfig(
toolgroups=[
"builtin::websearch",
],
client_tools=[
ToolDef(name="client_tool", description="Client provided tool")
]
client_tools=[ToolDef(name="client_tool", description="Client provided tool")],
)
```
@ -167,9 +159,9 @@ Example tool definition:
"name": "query",
"parameter_type": "string",
"description": "The query to search for",
"required": True
"required": True,
}
]
],
}
```
@ -179,8 +171,7 @@ Tools can be invoked using the `invoke_tool` method:
```python
result = client.tool_runtime.invoke_tool(
tool_name="web_search",
kwargs={"query": "What is the capital of France?"}
tool_name="web_search", kwargs={"query": "What is the capital of France?"}
)
```


@ -1,9 +1,9 @@
# Using Llama Stack as a Library
If you are planning to use an external service for Inference (even Ollama or TGI counts as external), it is often easier to use Llama Stack as a library. This avoids the overhead of setting up a server.
```python
```bash
# setup
pip install llama-stack
uv pip install llama-stack
llama stack build --template together --image-type venv
```
@ -13,7 +13,7 @@ from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
client = LlamaStackAsLibraryClient(
"ollama",
# provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
provider_data = {"tavily_search_api_key": os.environ['TAVILY_SEARCH_API_KEY']}
provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
)
await client.initialize()
```


@ -96,18 +96,26 @@ Here is a simple example to perform chat completions using the SDK.
```python
import os
def create_http_client():
from llama_stack_client import LlamaStackClient
return LlamaStackClient(base_url=f"http://localhost:{os.environ['LLAMA_STACK_PORT']}")
return LlamaStackClient(
base_url=f"http://localhost:{os.environ['LLAMA_STACK_PORT']}"
)
def create_library_client(template="ollama"):
from llama_stack import LlamaStackAsLibraryClient
client = LlamaStackAsLibraryClient(template)
client.initialize()
return client
client = create_library_client() # or create_http_client() depending on the environment you picked
client = (
create_library_client()
) # or create_http_client() depending on the environment you picked
# List available models
models = client.models.list()
@ -120,8 +128,8 @@ response = client.inference.chat_completion(
model_id=os.environ["INFERENCE_MODEL"],
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Write a haiku about coding"}
]
{"role": "user", "content": "Write a haiku about coding"},
],
)
print(response.completion_message.content)
```
@ -139,7 +147,9 @@ from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types import Document
client = create_library_client() # or create_http_client() depending on the environment you picked
client = (
create_library_client()
) # or create_http_client() depending on the environment you picked
# Documents to be used for RAG
urls = ["chat.rst", "llama3.rst", "datasets.rst", "lora_finetune.rst"]
@ -174,12 +184,12 @@ agent_config = AgentConfig(
instructions="You are a helpful assistant",
enable_session_persistence=False,
# Define tools available to the agent
toolgroups = [
toolgroups=[
{
"name": "builtin::rag",
"args" : {
"vector_db_ids": [vector_db_id],
}
"name": "builtin::rag",
"args": {
"vector_db_ids": [vector_db_id],
},
}
],
)
@ -193,7 +203,7 @@ user_prompts = [
# Run the agent loop by calling the `create_turn` method
for prompt in user_prompts:
cprint(f'User> {prompt}', 'green')
cprint(f"User> {prompt}", "green")
response = rag_agent.create_turn(
messages=[{"role": "user", "content": prompt}],
session_id=session_id,


@ -51,6 +51,7 @@ This first example walks you through how to evaluate a model candidate served by
```python
import datasets
ds = datasets.load_dataset(path="llamastack/mmmu", name="Agriculture", split="dev")
ds = ds.select_columns(["chat_completion_input", "input_query", "expected_answer"])
eval_rows = ds.to_pandas().to_dict(orient="records")
@ -79,7 +80,7 @@ system_message = {
client.eval_tasks.register(
eval_task_id="meta-reference::mmmu",
dataset_id=f"mmmu-{subset}-{split}",
scoring_functions=["basic::regex_parser_multiple_choice_answer"]
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
)
response = client.eval.evaluate_rows(
@ -98,9 +99,9 @@ response = client.eval.evaluate_rows(
"max_tokens": 4096,
"repeat_penalty": 1.0,
},
"system_message": system_message
}
}
"system_message": system_message,
},
},
)
```
@ -124,7 +125,7 @@ _ = client.datasets.register(
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "chat_completion_input"},
}
},
)
eval_rows = client.datasetio.get_rows_paginated(
@ -137,7 +138,7 @@ eval_rows = client.datasetio.get_rows_paginated(
client.eval_tasks.register(
eval_task_id="meta-reference::simpleqa",
dataset_id=simpleqa_dataset_id,
scoring_functions=["llm-as-judge::405b-simpleqa"]
scoring_functions=["llm-as-judge::405b-simpleqa"],
)
response = client.eval.evaluate_rows(
@ -156,8 +157,8 @@ response = client.eval.evaluate_rows(
"max_tokens": 4096,
"repeat_penalty": 1.0,
},
}
}
},
},
)
```
@ -180,14 +181,14 @@ agent_config = {
{
"type": "brave_search",
"engine": "tavily",
"api_key": userdata.get("TAVILY_SEARCH_API_KEY")
"api_key": userdata.get("TAVILY_SEARCH_API_KEY"),
}
],
"tool_choice": "auto",
"tool_prompt_format": "json",
"input_shields": [],
"output_shields": [],
"enable_session_persistence": False
"enable_session_persistence": False,
}
response = client.eval.evaluate_rows(
@ -199,8 +200,8 @@ response = client.eval.evaluate_rows(
"eval_candidate": {
"type": "agent",
"config": agent_config,
}
}
},
},
)
```
@ -237,7 +238,9 @@ GENERATED_RESPONSE: {generated_answer}
EXPECTED_RESPONSE: {expected_answer}
"""
input_query = "What are the top 5 topics that were explained? Only list succinct bullet points."
input_query = (
"What are the top 5 topics that were explained? Only list succinct bullet points."
)
generated_answer = """
Here are the top 5 topics that were explained in the documentation for Torchtune:
@ -268,7 +271,9 @@ scoring_params = {
"braintrust::factuality": None,
}
response = client.scoring.score(input_rows=dataset_rows, scoring_functions=scoring_params)
response = client.scoring.score(
input_rows=dataset_rows, scoring_functions=scoring_params
)
```
## Running Evaluations via CLI


@ -33,7 +33,11 @@ from llama_stack_client.types import (
Types:
```python
from llama_stack_client.types import ListToolGroupsResponse, ToolGroup, ToolgroupListResponse
from llama_stack_client.types import (
ListToolGroupsResponse,
ToolGroup,
ToolgroupListResponse,
)
```
Methods:
@ -444,7 +448,11 @@ Methods:
Types:
```python
from llama_stack_client.types import EvalTask, ListEvalTasksResponse, EvalTaskListResponse
from llama_stack_client.types import (
EvalTask,
ListEvalTasksResponse,
EvalTaskListResponse,
)
```
Methods:


@ -224,7 +224,7 @@ client = LlamaStackClient(base_url="http://localhost:5001")
response = client.inference.chat_completion(
messages=[
{"role": "system", "content": "You are a friendly assistant."},
{"role": "user", "content": "Write a two-sentence poem about llama."}
{"role": "user", "content": "Write a two-sentence poem about llama."},
],
model_id=INFERENCE_MODEL,
)


@ -86,9 +86,7 @@ class ShieldCallStep(StepCommon):
@json_schema_type
class MemoryRetrievalStep(StepCommon):
step_type: Literal[StepType.memory_retrieval.value] = (
StepType.memory_retrieval.value
)
step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value
vector_db_ids: str
inserted_context: InterleavedContent
@ -184,9 +182,7 @@ class AgentTurnResponseEventType(Enum):
@json_schema_type
class AgentTurnResponseStepStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_start.value] = (
AgentTurnResponseEventType.step_start.value
)
event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value
step_type: StepType
step_id: str
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
@ -194,9 +190,7 @@ class AgentTurnResponseStepStartPayload(BaseModel):
@json_schema_type
class AgentTurnResponseStepCompletePayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_complete.value] = (
AgentTurnResponseEventType.step_complete.value
)
event_type: Literal[AgentTurnResponseEventType.step_complete.value] = AgentTurnResponseEventType.step_complete.value
step_type: StepType
step_id: str
step_details: Step
@ -206,9 +200,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
class AgentTurnResponseStepProgressPayload(BaseModel):
model_config = ConfigDict(protected_namespaces=())
event_type: Literal[AgentTurnResponseEventType.step_progress.value] = (
AgentTurnResponseEventType.step_progress.value
)
event_type: Literal[AgentTurnResponseEventType.step_progress.value] = AgentTurnResponseEventType.step_progress.value
step_type: StepType
step_id: str
@ -217,17 +209,13 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
@json_schema_type
class AgentTurnResponseTurnStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_start.value] = (
AgentTurnResponseEventType.turn_start.value
)
event_type: Literal[AgentTurnResponseEventType.turn_start.value] = AgentTurnResponseEventType.turn_start.value
turn_id: str
@json_schema_type
class AgentTurnResponseTurnCompletePayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = (
AgentTurnResponseEventType.turn_complete.value
)
event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = AgentTurnResponseEventType.turn_complete.value
turn: Turn
@ -329,9 +317,7 @@ class Agents(Protocol):
toolgroups: Optional[List[AgentToolGroup]] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET"
)
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET")
async def get_agents_turn(
self,
agent_id: str,


@ -63,9 +63,7 @@ class EventLogger:
if isinstance(chunk, ToolResponseMessage):
yield (
chunk,
LogEvent(
role="CustomTool", content=chunk.content, color="grey"
),
LogEvent(role="CustomTool", content=chunk.content, color="grey"),
)
continue
@ -81,17 +79,12 @@ class EventLogger:
step_type = event.payload.step_type
# handle safety
if (
step_type == StepType.shield_call
and event_type == EventType.step_complete.value
):
if step_type == StepType.shield_call and event_type == EventType.step_complete.value:
violation = event.payload.step_details.violation
if not violation:
yield (
event,
LogEvent(
role=step_type, content="No Violation", color="magenta"
),
LogEvent(role=step_type, content="No Violation", color="magenta"),
)
else:
yield (
@ -110,9 +103,7 @@ class EventLogger:
# TODO: Currently this event is never received
yield (
event,
LogEvent(
role=step_type, content="", end="", color="yellow"
),
LogEvent(role=step_type, content="", end="", color="yellow"),
)
elif event_type == EventType.step_progress.value:
# HACK: if previous was not step/event was not inference's step_progress
@ -125,9 +116,7 @@ class EventLogger:
):
yield (
event,
LogEvent(
role=step_type, content="", end="", color="yellow"
),
LogEvent(role=step_type, content="", end="", color="yellow"),
)
delta = event.payload.delta
@ -161,9 +150,7 @@ class EventLogger:
if event_type == EventType.step_complete.value:
response = event.payload.step_details.model_response
if response.tool_calls:
content = ToolUtils.encode_tool_call(
response.tool_calls[0], tool_prompt_format
)
content = ToolUtils.encode_tool_call(response.tool_calls[0], tool_prompt_format)
else:
content = response.content
yield (
@ -202,10 +189,7 @@ class EventLogger:
),
)
if (
step_type == StepType.memory_retrieval
and event_type == EventType.step_complete.value
):
if step_type == StepType.memory_retrieval and event_type == EventType.step_complete.value:
details = event.payload.step_details
inserted_context = interleaved_content_as_str(details.inserted_context)
content = f"fetched {len(inserted_context)} bytes from {details.vector_db_ids}"


@ -39,6 +39,4 @@ class DatasetIO(Protocol):
) -> PaginatedRowsResult: ...
@webmethod(route="/datasetio/rows", method="POST")
async def append_rows(
self, dataset_id: str, rows: List[Dict[str, Any]]
) -> None: ...
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...


@ -63,9 +63,7 @@ class AppEvalTaskConfig(BaseModel):
EvalTaskConfig = register_schema(
Annotated[
Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type")
],
Annotated[Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type")],
name="EvalTaskConfig",
)


@ -245,9 +245,7 @@ class JsonSchemaResponseFormat(BaseModel):
:param json_schema: The JSON schema the response should conform to. In a Python SDK, this is often a `pydantic` model.
"""
type: Literal[ResponseFormatType.json_schema.value] = (
ResponseFormatType.json_schema.value
)
type: Literal[ResponseFormatType.json_schema.value] = ResponseFormatType.json_schema.value
json_schema: Dict[str, Any]
@ -406,9 +404,7 @@ class Inference(Protocol):
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
"""Generate a chat completion for the given messages using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.


@ -89,9 +89,7 @@ class QATFinetuningConfig(BaseModel):
AlgorithmConfig = register_schema(
Annotated[
Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
],
Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")],
name="AlgorithmConfig",
)
@ -204,14 +202,10 @@ class PostTraining(Protocol):
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
@webmethod(route="/post-training/job/status", method="GET")
async def get_training_job_status(
self, job_uuid: str
) -> Optional[PostTrainingJobStatusResponse]: ...
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: ...
@webmethod(route="/post-training/job/cancel", method="POST")
async def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post-training/job/artifacts", method="GET")
async def get_training_job_artifacts(
self, job_uuid: str
) -> Optional[PostTrainingJobArtifactsResponse]: ...
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]: ...


@ -23,9 +23,7 @@ class ResourceType(Enum):
class Resource(BaseModel):
"""Base class for all Llama Stack resources"""
identifier: str = Field(
description="Unique identifier for this resource in llama stack"
)
identifier: str = Field(description="Unique identifier for this resource in llama stack")
provider_resource_id: str = Field(
description="Unique identifier for this resource in the provider",
@ -34,6 +32,4 @@ class Resource(BaseModel):
provider_id: str = Field(description="ID of the provider that owns this resource")
type: ResourceType = Field(
description="Type of resource (e.g. 'model', 'shield', 'vector_db', etc.)"
)
type: ResourceType = Field(description="Type of resource (e.g. 'model', 'shield', 'vector_db', etc.)")


@ -43,9 +43,7 @@ class AggregationFunctionType(Enum):
@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.llm_as_judge.value] = (
ScoringFnParamsType.llm_as_judge.value
)
type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
judge_model: str
prompt_template: Optional[str] = None
judge_score_regexes: Optional[List[str]] = Field(
@ -60,9 +58,7 @@ class LLMAsJudgeScoringFnParams(BaseModel):
@json_schema_type
class RegexParserScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.regex_parser.value] = (
ScoringFnParamsType.regex_parser.value
)
type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
parsing_regexes: Optional[List[str]] = Field(
description="Regex to extract the answer from generated response",
default_factory=list,
@ -112,9 +108,7 @@ class CommonScoringFnFields(BaseModel):
@json_schema_type
class ScoringFn(CommonScoringFnFields, Resource):
type: Literal[ResourceType.scoring_function.value] = (
ResourceType.scoring_function.value
)
type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value
@property
def scoring_fn_id(self) -> str:
@ -141,9 +135,7 @@ class ScoringFunctions(Protocol):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
@webmethod(route="/scoring-functions/{scoring_fn_id}", method="GET")
async def get_scoring_function(
self, scoring_fn_id: str, /
) -> Optional[ScoringFn]: ...
async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ...
@webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function(


@ -102,9 +102,7 @@ class StructuredLogType(Enum):
@json_schema_type
class SpanStartPayload(BaseModel):
type: Literal[StructuredLogType.SPAN_START.value] = (
StructuredLogType.SPAN_START.value
)
type: Literal[StructuredLogType.SPAN_START.value] = StructuredLogType.SPAN_START.value
name: str
parent_span_id: Optional[str] = None
@ -190,9 +188,7 @@ class QuerySpanTreeResponse(BaseModel):
@runtime_checkable
class Telemetry(Protocol):
@webmethod(route="/telemetry/events", method="POST")
async def log_event(
self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400
) -> None: ...
async def log_event(self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400) -> None: ...
@webmethod(route="/telemetry/traces", method="GET")
async def query_traces(


@ -64,9 +64,7 @@ RAGQueryGeneratorConfig = register_schema(
class RAGQueryConfig(BaseModel):
# This config defines how a query is generated using the messages
# for memory bank retrieval.
query_generator_config: RAGQueryGeneratorConfig = Field(
default=DefaultRAGQueryGeneratorConfig()
)
query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
max_tokens_in_context: int = 4096
max_chunks: int = 5


@ -150,8 +150,6 @@ class ToolRuntime(Protocol):
) -> List[ToolDef]: ...
@webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
"""Run a tool with the given arguments"""
...


@ -147,9 +147,7 @@ class ParallelDownloader:
"follow_redirects": True,
}
async def retry_with_exponential_backoff(
self, task: DownloadTask, func, *args, **kwargs
):
async def retry_with_exponential_backoff(self, task: DownloadTask, func, *args, **kwargs):
last_exception = None
for attempt in range(task.max_retries):
try:
@ -166,13 +164,9 @@ class ParallelDownloader:
continue
raise last_exception
async def get_file_info(
self, client: httpx.AsyncClient, task: DownloadTask
) -> None:
async def get_file_info(self, client: httpx.AsyncClient, task: DownloadTask) -> None:
async def _get_info():
response = await client.head(
task.url, headers={"Accept-Encoding": "identity"}, **self.client_options
)
response = await client.head(task.url, headers={"Accept-Encoding": "identity"}, **self.client_options)
response.raise_for_status()
return response
@ -201,14 +195,10 @@ class ParallelDownloader:
return False
return os.path.getsize(task.output_file) == task.total_size
async def download_chunk(
self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int
) -> None:
async def download_chunk(self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int) -> None:
async def _download_chunk():
headers = {"Range": f"bytes={start}-{end}"}
async with client.stream(
"GET", task.url, headers=headers, **self.client_options
) as response:
async with client.stream("GET", task.url, headers=headers, **self.client_options) as response:
response.raise_for_status()
with open(task.output_file, "ab") as file:
@ -225,8 +215,7 @@ class ParallelDownloader:
await self.retry_with_exponential_backoff(task, _download_chunk)
except Exception as e:
raise DownloadError(
f"Failed to download chunk {start}-{end} after "
f"{task.max_retries} attempts: {str(e)}"
f"Failed to download chunk {start}-{end} after {task.max_retries} attempts: {str(e)}"
) from e
async def prepare_download(self, task: DownloadTask) -> None:
@ -244,9 +233,7 @@ class ParallelDownloader:
# Check if file is already downloaded
if os.path.exists(task.output_file):
if self.verify_file_integrity(task):
self.console.print(
f"[green]Already downloaded {task.output_file}[/green]"
)
self.console.print(f"[green]Already downloaded {task.output_file}[/green]")
self.progress.update(task.task_id, completed=task.total_size)
return
@ -259,9 +246,7 @@ class ParallelDownloader:
current_pos = task.downloaded_size
while current_pos < task.total_size:
chunk_end = min(
current_pos + chunk_size - 1, task.total_size - 1
)
chunk_end = min(current_pos + chunk_size - 1, task.total_size - 1)
chunks.append((current_pos, chunk_end))
current_pos = chunk_end + 1
@ -273,18 +258,12 @@ class ParallelDownloader:
raise DownloadError(f"Download failed: {str(e)}") from e
except Exception as e:
self.progress.update(
task.task_id, description=f"[red]Failed: {task.output_file}[/red]"
)
raise DownloadError(
f"Download failed for {task.output_file}: {str(e)}"
) from e
self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e
def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
try:
total_remaining_size = sum(
task.total_size - task.downloaded_size for task in tasks
)
total_remaining_size = sum(task.total_size - task.downloaded_size for task in tasks)
dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
free_space = shutil.disk_usage(dir_path).free
@ -314,9 +293,7 @@ class ParallelDownloader:
with self.progress:
for task in tasks:
desc = f"Downloading {Path(task.output_file).name}"
task.task_id = self.progress.add_task(
desc, total=task.total_size, completed=task.downloaded_size
)
task.task_id = self.progress.add_task(desc, total=task.total_size, completed=task.downloaded_size)
semaphore = asyncio.Semaphore(self.max_concurrent_downloads)
@ -332,9 +309,7 @@ class ParallelDownloader:
if failed_tasks:
self.console.print("\n[red]Some downloads failed:[/red]")
for task, error in failed_tasks:
self.console.print(
f"[red]- {Path(task.output_file).name}: {error}[/red]"
)
self.console.print(f"[red]- {Path(task.output_file).name}: {error}[/red]")
raise DownloadError(f"{len(failed_tasks)} downloads failed")
@ -396,11 +371,7 @@ def _meta_download(
output_file = str(output_dir / f)
url = meta_url.replace("*", f"{info.folder}/{f}")
total_size = info.pth_size if "consolidated" in f else 0
tasks.append(
DownloadTask(
url=url, output_file=output_file, total_size=total_size, max_retries=3
)
)
tasks.append(DownloadTask(url=url, output_file=output_file, total_size=total_size, max_retries=3))
# Initialize and run parallel downloader
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
@ -446,14 +417,10 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
os.makedirs(output_dir, exist_ok=True)
if any(output_dir.iterdir()):
console.print(
f"[yellow]Output directory {output_dir} is not empty.[/yellow]"
)
console.print(f"[yellow]Output directory {output_dir} is not empty.[/yellow]")
while True:
resp = input(
"Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
)
resp = input("Do you want to (C)ontinue download or (R)estart completely? (continue/restart): ")
if resp.lower() in ["restart", "r"]:
shutil.rmtree(output_dir)
os.makedirs(output_dir, exist_ok=True)
@ -471,9 +438,7 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
]
# Initialize and run parallel downloader
downloader = ParallelDownloader(
max_concurrent_downloads=max_concurrent_downloads
)
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
asyncio.run(downloader.download_all(tasks))


@ -47,33 +47,20 @@ class ModelPromptFormat(Subcommand):
# Only Llama 3.1 and 3.2 are supported
supported_model_ids = [
m
for m in CoreModelId
if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
m for m in CoreModelId if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
]
model_str = "\n".join([m.value for m in supported_model_ids])
try:
model_id = CoreModelId(args.model_name)
except ValueError:
self.parser.error(
f"{args.model_name} is not a valid Model. Choose one from --\n{model_str}"
)
self.parser.error(f"{args.model_name} is not a valid Model. Choose one from --\n{model_str}")
if model_id not in supported_model_ids:
self.parser.error(
f"{model_id} is not a valid Model. Choose one from --\n {model_str}"
)
self.parser.error(f"{model_id} is not a valid Model. Choose one from --\n {model_str}")
llama_3_1_file = (
importlib.resources.files("llama_models") / "llama3_1/prompt_format.md"
)
llama_3_2_text_file = (
importlib.resources.files("llama_models") / "llama3_2/text_prompt_format.md"
)
llama_3_2_vision_file = (
importlib.resources.files("llama_models")
/ "llama3_2/vision_prompt_format.md"
)
llama_3_1_file = importlib.resources.files("llama_models") / "llama3_1/prompt_format.md"
llama_3_2_text_file = importlib.resources.files("llama_models") / "llama3_2/text_prompt_format.md"
llama_3_2_vision_file = importlib.resources.files("llama_models") / "llama3_2/vision_prompt_format.md"
if model_family(model_id) == ModelFamily.llama3_1:
with importlib.resources.as_file(llama_3_1_file) as f:
content = f.open("r").read()


@ -17,16 +17,12 @@ class PromptGuardModel(BaseModel):
"""Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""
model_id: str = "Prompt-Guard-86M"
description: str = (
"Prompt Guard. NOTE: this model will not be provided via `llama` CLI soon."
)
description: str = "Prompt Guard. NOTE: this model will not be provided via `llama` CLI soon."
is_featured: bool = False
huggingface_repo: str = "meta-llama/Prompt-Guard-86M"
max_seq_length: int = 2048
is_instruct_model: bool = False
quantization_format: CheckpointQuantizationFormat = (
CheckpointQuantizationFormat.bf16
)
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
arch_args: Dict[str, Any] = Field(default_factory=dict)
recommended_sampling_params: Optional[SamplingParams] = None


@ -56,9 +56,7 @@ def available_templates_specs() -> Dict[str, BuildConfig]:
return template_specs
def run_stack_build_command(
parser: argparse.ArgumentParser, args: argparse.Namespace
) -> None:
def run_stack_build_command(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
if args.list_templates:
return _run_template_list_cmd()
@ -129,11 +127,7 @@ def run_stack_build_command(
providers = dict()
for api, providers_for_api in get_provider_registry().items():
available_providers = [
x
for x in providers_for_api.keys()
if x not in ("remote", "remote::sample")
]
available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")]
api_provider = prompt(
"> Enter provider for API {}: ".format(api.value),
completer=WordCompleter(available_providers),
@ -156,9 +150,7 @@ def run_stack_build_command(
description=description,
)
build_config = BuildConfig(
image_type=image_type, distribution_spec=distribution_spec
)
build_config = BuildConfig(image_type=image_type, distribution_spec=distribution_spec)
else:
with open(args.config, "r") as f:
try:
@ -179,9 +171,7 @@ def run_stack_build_command(
if args.print_deps_only:
print(f"# Dependencies for {args.template or args.config or image_name}")
normal_deps, special_deps = get_provider_dependencies(
build_config.distribution_spec.providers
)
normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
normal_deps += SERVER_DEPENDENCIES
print(f"uv pip install {' '.join(normal_deps)}")
for special_dep in special_deps:
@ -206,9 +196,7 @@ def _generate_run_config(
"""
apis = list(build_config.distribution_spec.providers.keys())
run_config = StackRunConfig(
container_image=(
image_name if build_config.image_type == ImageType.container.value else None
),
container_image=(image_name if build_config.image_type == ImageType.container.value else None),
image_name=image_name,
apis=apis,
providers={},
@ -228,13 +216,9 @@ def _generate_run_config(
if p.deprecation_error:
raise InvalidProviderError(p.deprecation_error)
config_type = instantiate_class_type(
provider_registry[Api(api)][provider_type].config_class
)
config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
if hasattr(config_type, "sample_run_config"):
config = config_type.sample_run_config(
__distro_dir__=f"distributions/{image_name}"
)
config = config_type.sample_run_config(__distro_dir__=f"distributions/{image_name}")
else:
config = {}
@ -269,9 +253,7 @@ def _run_stack_build_command_from_build_config(
image_name = f"distribution-{template_name}"
else:
if not image_name:
raise ValueError(
"Please specify an image name when building a container image without a template"
)
raise ValueError("Please specify an image name when building a container image without a template")
elif build_config.image_type == ImageType.conda.value:
if not image_name:
raise ValueError("Please specify an image name when building a conda image")
@ -299,10 +281,7 @@ def _run_stack_build_command_from_build_config(
if template_name:
# copy run.yaml from template to build_dir instead of generating it again
template_path = (
importlib.resources.files("llama_stack")
/ f"templates/{template_name}/run.yaml"
)
template_path = importlib.resources.files("llama_stack") / f"templates/{template_name}/run.yaml"
with importlib.resources.as_file(template_path) as path:
run_config_file = build_dir / f"{template_name}-run.yaml"
shutil.copy(path, run_config_file)


@ -82,31 +82,21 @@ class StackRun(Subcommand):
if not config_file.exists() and not has_yaml_suffix:
# check if this is a template
config_file = (
Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
)
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
if config_file.exists():
template_name = args.config
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to conda dir
config_file = Path(
BUILDS_BASE_DIR / ImageType.conda.value / f"{args.config}-run.yaml"
)
config_file = Path(BUILDS_BASE_DIR / ImageType.conda.value / f"{args.config}-run.yaml")
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to container dir
config_file = Path(
BUILDS_BASE_DIR / ImageType.container.value / f"{args.config}-run.yaml"
)
config_file = Path(BUILDS_BASE_DIR / ImageType.container.value / f"{args.config}-run.yaml")
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
config_file = Path(
DISTRIBS_BASE_DIR
/ f"llamastack-{args.config}"
/ f"{args.config}-run.yaml"
)
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
if not config_file.exists():
self.parser.error(
@ -119,15 +109,8 @@ class StackRun(Subcommand):
config = parse_and_maybe_upgrade_config(config_dict)
if config.container_image:
script = (
importlib.resources.files("llama_stack")
/ "distribution/start_container.sh"
)
image_name = (
f"distribution-{template_name}"
if template_name
else config.container_image
)
script = importlib.resources.files("llama_stack") / "distribution/start_container.sh"
image_name = f"distribution-{template_name}" if template_name else config.container_image
run_args = [script, image_name]
else:
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
@ -145,11 +128,7 @@ class StackRun(Subcommand):
if env_name == "base":
return os.environ.get("CONDA_PREFIX")
# Get conda environments info
conda_env_info = json.loads(
subprocess.check_output(
["conda", "info", "--envs", "--json"]
).decode()
)
conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
envs = conda_env_info["envs"]
for envpath in envs:
if envpath.endswith(env_name):
@ -173,10 +152,7 @@ class StackRun(Subcommand):
)
return
script = (
importlib.resources.files("llama_stack")
/ "distribution/start_conda_env.sh"
)
script = importlib.resources.files("llama_stack") / "distribution/start_conda_env.sh"
run_args = [
script,
image_name,


@ -22,11 +22,7 @@ def format_row(row, col_widths):
if line.strip() == "":
lines.append("")
else:
lines.extend(
textwrap.wrap(
line, width, break_long_words=False, replace_whitespace=False
)
)
lines.extend(textwrap.wrap(line, width, break_long_words=False, replace_whitespace=False))
return lines
wrapped = [wrap(item, width) for item, width in zip(row, col_widths)]


@ -41,9 +41,7 @@ def up_to_date_config():
- provider_id: provider1
provider_type: inline::meta-reference
config: {{}}
""".format(
version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat()
)
""".format(version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat())
)
@ -83,9 +81,7 @@ def old_config():
telemetry:
provider_type: noop
config: {{}}
""".format(
built_at=datetime.now().isoformat()
)
""".format(built_at=datetime.now().isoformat())
)
@ -108,10 +104,7 @@ def test_parse_and_maybe_upgrade_config_up_to_date(up_to_date_config):
def test_parse_and_maybe_upgrade_config_old_format(old_config):
result = parse_and_maybe_upgrade_config(old_config)
assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
assert all(
api in result.providers
for api in ["inference", "safety", "memory", "telemetry"]
)
assert all(api in result.providers for api in ["inference", "safety", "memory", "telemetry"])
safety_provider = result.providers["safety"][0]
assert safety_provider.provider_type == "meta-reference"
assert "llama_guard_shield" in safety_provider.config


@ -72,9 +72,7 @@ def load_checksums(checklist_path: Path) -> Dict[str, str]:
return checksums
def verify_files(
model_dir: Path, checksums: Dict[str, str], console: Console
) -> List[VerificationResult]:
def verify_files(model_dir: Path, checksums: Dict[str, str], console: Console) -> List[VerificationResult]:
results = []
with Progress(


@ -58,22 +58,14 @@ def get_provider_dependencies(
for api_str, provider_or_providers in config_providers.items():
providers_for_api = all_providers[Api(api_str)]
providers = (
provider_or_providers
if isinstance(provider_or_providers, list)
else [provider_or_providers]
)
providers = provider_or_providers if isinstance(provider_or_providers, list) else [provider_or_providers]
for provider in providers:
# Providers from BuildConfig and RunConfig are subtly different - not great
provider_type = (
provider if isinstance(provider, str) else provider.provider_type
)
provider_type = provider if isinstance(provider, str) else provider.provider_type
if provider_type not in providers_for_api:
raise ValueError(
f"Provider `{provider}` is not available for API `{api_str}`"
)
raise ValueError(f"Provider `{provider}` is not available for API `{api_str}`")
provider_spec = providers_for_api[provider_type]
deps.extend(provider_spec.pip_packages)
@ -109,19 +101,13 @@ def build_image(
image_name: str,
template_or_config: str,
):
container_base = (
build_config.distribution_spec.container_image or "python:3.10-slim"
)
container_base = build_config.distribution_spec.container_image or "python:3.10-slim"
normal_deps, special_deps = get_provider_dependencies(
build_config.distribution_spec.providers
)
normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
normal_deps += SERVER_DEPENDENCIES
if build_config.image_type == ImageType.container.value:
script = str(
importlib.resources.files("llama_stack") / "distribution/build_container.sh"
)
script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")
args = [
script,
template_or_config,
@ -132,9 +118,7 @@ def build_image(
" ".join(normal_deps),
]
elif build_config.image_type == ImageType.conda.value:
script = str(
importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh"
)
script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
args = [
script,
str(image_name),
@ -142,9 +126,7 @@ def build_image(
" ".join(normal_deps),
]
elif build_config.image_type == ImageType.venv.value:
script = str(
importlib.resources.files("llama_stack") / "distribution/build_venv.sh"
)
script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh")
args = [
script,
str(image_name),


@ -68,9 +68,7 @@ def create_api_client_class(protocol) -> Type:
return_type = None
else:
return_type = extract_non_async_iterator_type(sig.return_annotation)
assert return_type, (
f"Could not extract return type for {sig.return_annotation}"
)
assert return_type, f"Could not extract return type for {sig.return_annotation}"
async with httpx.AsyncClient() as client:
params = self.httpx_request_params(method_name, *args, **kwargs)
@ -87,9 +85,7 @@ def create_api_client_class(protocol) -> Type:
webmethod, sig = self.routes[method_name]
return_type = extract_async_iterator_type(sig.return_annotation)
assert return_type, (
f"Could not extract return type for {sig.return_annotation}"
)
assert return_type, f"Could not extract return type for {sig.return_annotation}"
async with httpx.AsyncClient() as client:
params = self.httpx_request_params(method_name, *args, **kwargs)
@ -204,9 +200,7 @@ async def example(model: str = None):
if not model:
model = "Llama3.2-3B-Instruct"
message = UserMessage(
content="hello world, write me a 2 sentence poem about the moon"
)
message = UserMessage(content="hello world, write me a 2 sentence poem about the moon")
cprint(f"User>{message.content}", "green")
stream = True


@ -26,9 +26,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
logger = logging.getLogger(__name__)
def configure_single_provider(
registry: Dict[str, ProviderSpec], provider: Provider
) -> Provider:
def configure_single_provider(registry: Dict[str, ProviderSpec], provider: Provider) -> Provider:
provider_spec = registry[provider.provider_type]
config_type = instantiate_class_type(provider_spec.config_class)
try:
@ -47,9 +45,7 @@ def configure_single_provider(
)
def configure_api_providers(
config: StackRunConfig, build_spec: DistributionSpec
) -> StackRunConfig:
def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec) -> StackRunConfig:
is_nux = len(config.providers) == 0
if is_nux:
@ -87,9 +83,7 @@ def configure_api_providers(
updated_providers = []
for p in existing_providers:
logger.info(f"> Configuring provider `({p.provider_type})`")
updated_providers.append(
configure_single_provider(provider_registry[api], p)
)
updated_providers.append(configure_single_provider(provider_registry[api], p))
logger.info("")
else:
# we are newly configuring this API
@ -114,11 +108,7 @@ def configure_api_providers(
configure_single_provider(
provider_registry[api],
Provider(
provider_id=(
f"{provider_type}-{i:02d}"
if len(plist) > 1
else provider_type
),
provider_id=(f"{provider_type}-{i:02d}" if len(plist) > 1 else provider_type),
provider_type=provider_type,
config={},
),
@ -137,11 +127,7 @@ def upgrade_from_routing_table(
def get_providers(entries):
return [
Provider(
provider_id=(
f"{entry['provider_type']}-{i:02d}"
if len(entries) > 1
else entry["provider_type"]
),
provider_id=(f"{entry['provider_type']}-{i:02d}" if len(entries) > 1 else entry["provider_type"]),
provider_type=entry["provider_type"],
config=entry["config"],
)


@ -163,9 +163,7 @@ a default SQLite store will be used.""",
class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
distribution_spec: DistributionSpec = Field(
description="The distribution spec to build including API providers. "
)
distribution_spec: DistributionSpec = Field(description="The distribution spec to build including API providers. ")
image_type: str = Field(
default="conda",
description="Type of package to build (conda | container | venv)",


@ -55,9 +55,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
def providable_apis() -> List[Api]:
routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis()
)
routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
return [api for api in Api if api not in routing_table_apis and api != Api.inspect]


@ -154,9 +154,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
def sync_generator():
try:
async_stream = loop.run_until_complete(
self.async_client.request(*args, **kwargs)
)
async_stream = loop.run_until_complete(self.async_client.request(*args, **kwargs))
while True:
chunk = loop.run_until_complete(async_stream.__anext__())
yield chunk
@ -181,9 +179,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
# when using the library client, we should not log to console since many
# of our logs are intended for server-side usage
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
os.environ["TELEMETRY_SINKS"] = ",".join(
sink for sink in current_sinks if sink != "console"
)
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
if config_path_or_template_name.endswith(".yaml"):
config_path = Path(config_path_or_template_name)
@ -202,9 +198,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
async def initialize(self):
try:
self.impls = await construct_stack(
self.config, self.custom_provider_registry
)
self.impls = await construct_stack(self.config, self.custom_provider_registry)
except ModuleNotFoundError as _e:
cprint(_e.msg, "red")
cprint(
@ -247,9 +241,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
func = getattr(impl, endpoint.name)
if endpoint.method not in endpoint_impls:
endpoint_impls[endpoint.method] = {}
endpoint_impls[endpoint.method][
_convert_path_to_regex(endpoint.route)
] = func
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = func
self.endpoint_impls = endpoint_impls
return True
@ -266,9 +258,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
raise ValueError("Client not initialized")
if self.provider_data:
set_request_provider_data(
{"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
)
set_request_provider_data({"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)})
if stream:
response = await self._call_streaming(
@ -408,9 +398,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
)
return await response.parse()
def _convert_body(
self, path: str, method: str, body: Optional[dict] = None
) -> dict:
def _convert_body(self, path: str, method: str, body: Optional[dict] = None) -> dict:
if not body:
return {}
@ -425,7 +413,5 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
for param_name, param in sig.parameters.items():
if param_name in body:
value = body.get(param_name)
converted_body[param_name] = convert_to_pydantic(
param.annotation, value
)
converted_body[param_name] = convert_to_pydantic(param.annotation, value)
return converted_body


@ -115,9 +115,7 @@ async def resolve_impls(
- flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation
"""
routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis()
)
routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
router_apis = set(x.router_api for x in builtin_automatically_routed_apis())
providers_with_specs = {}
@ -125,16 +123,12 @@ async def resolve_impls(
for api_str, providers in run_config.providers.items():
api = Api(api_str)
if api in routing_table_apis:
raise ValueError(
f"Provider for `{api_str}` is automatically provided and cannot be overridden"
)
raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")
specs = {}
for provider in providers:
if provider.provider_type not in provider_registry[api]:
raise ValueError(
f"Provider `{provider.provider_type}` is not available for API `{api}`"
)
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
@ -145,9 +139,7 @@ async def resolve_impls(
log.warning(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
)
p.deps__ = [a.value for a in p.api_dependencies] + [
a.value for a in p.optional_api_dependencies
]
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
spec = ProviderWithSpec(
spec=p,
**(provider.model_dump()),
@ -158,9 +150,7 @@ async def resolve_impls(
providers_with_specs[key] = specs
apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys())
+ [x.value for x in routing_table_apis]
+ [x.value for x in router_apis]
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
)
for info in builtin_automatically_routed_apis():
@ -197,9 +187,7 @@ async def resolve_impls(
)
}
sorted_providers = topological_sort(
{k: v.values() for k, v in providers_with_specs.items()}
)
sorted_providers = topological_sort({k: v.values() for k, v in providers_with_specs.items()})
apis = [x[1].spec.api for x in sorted_providers]
sorted_providers.append(
(
@ -237,9 +225,7 @@ async def resolve_impls(
inner_impls = {}
if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[
f"inner-{provider.spec.router_api.value}"
]
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
impl = await instantiate_provider(
provider,
@ -336,10 +322,7 @@ async def instantiate_provider(
# TODO: check compliance for special tool groups
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
check_protocol_compliance(impl, protocols[provider_spec.api])
if (
not isinstance(provider_spec, AutoRoutedProviderSpec)
and provider_spec.api in additional_protocols
):
if not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols:
additional_api, _, _ = additional_protocols[provider_spec.api]
check_protocol_compliance(impl, additional_api)
@ -367,19 +350,12 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
obj_params = set(obj_sig.parameters)
obj_params.discard("self")
if not (proto_params <= obj_params):
log.error(
f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
)
log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method is actually implemented in the class
method_owner = next(
(cls for cls in mro if name in cls.__dict__), None
)
if (
method_owner is None
or method_owner.__name__ == protocol.__name__
):
method_owner = next((cls for cls in mro if name in cls.__dict__), None)
if method_owner is None or method_owner.__name__ == protocol.__name__:
missing_methods.append((name, "not_actually_implemented"))
if missing_methods:

@ -85,9 +85,7 @@ class VectorIORouter(VectorIO):
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
) -> None:
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
vector_db_id, chunks, ttl_seconds
)
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
async def query_chunks(
self,
@ -95,9 +93,7 @@ class VectorIORouter(VectorIO):
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryChunksResponse:
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
vector_db_id, query, params
)
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
class InferenceRouter(Inference):
@ -123,9 +119,7 @@ class InferenceRouter(Inference):
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> None:
await self.routing_table.register_model(
model_id, provider_model_id, provider_id, metadata, model_type
)
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
async def chat_completion(
self,
@ -143,9 +137,7 @@ class InferenceRouter(Inference):
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding:
raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
params = dict(
model_id=model_id,
messages=messages,
@ -176,9 +168,7 @@ class InferenceRouter(Inference):
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding:
raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
provider = self.routing_table.get_provider_impl(model_id)
params = dict(
model_id=model_id,
@ -202,9 +192,7 @@ class InferenceRouter(Inference):
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.llm:
raise ValueError(
f"Model '{model_id}' is an LLM model and does not support embeddings"
)
raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
return await self.routing_table.get_provider_impl(model_id).embeddings(
model_id=model_id,
contents=contents,
@ -231,9 +219,7 @@ class SafetyRouter(Safety):
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
) -> Shield:
return await self.routing_table.register_shield(
shield_id, provider_shield_id, provider_id, params
)
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def run_shield(
self,
@ -268,9 +254,7 @@ class DatasetIORouter(DatasetIO):
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
return await self.routing_table.get_provider_impl(
dataset_id
).get_rows_paginated(
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=rows_in_page,
page_token=page_token,
@ -305,9 +289,7 @@ class ScoringRouter(Scoring):
) -> ScoreBatchResponse:
res = {}
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(
fn_identifier
).score_batch(
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
dataset_id=dataset_id,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
@ -328,9 +310,7 @@ class ScoringRouter(Scoring):
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(
fn_identifier
).score(
score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
input_rows=input_rows,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
@ -381,9 +361,7 @@ class EvalRouter(Eval):
task_id: str,
job_id: str,
) -> Optional[JobStatus]:
return await self.routing_table.get_provider_impl(task_id).job_status(
task_id, job_id
)
return await self.routing_table.get_provider_impl(task_id).job_status(task_id, job_id)
async def job_cancel(
self,
@ -420,9 +398,9 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult:
return await self.routing_table.get_provider_impl(
"query_from_memory"
).query(content, vector_db_ids, query_config)
return await self.routing_table.get_provider_impl("query_from_memory").query(
content, vector_db_ids, query_config
)
async def insert(
self,
@ -430,9 +408,9 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
return await self.routing_table.get_provider_impl(
"insert_into_memory"
).insert(documents, vector_db_id, chunk_size_in_tokens)
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
documents, vector_db_id, chunk_size_in_tokens
)
def __init__(
self,
@ -460,6 +438,4 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
tool_group_id, mcp_endpoint
)
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)

@ -94,9 +94,7 @@ class CommonRoutingTableImpl(RoutingTable):
self.dist_registry = dist_registry
async def initialize(self) -> None:
async def add_objects(
objs: List[RoutableObjectWithProvider], provider_id: str, cls
) -> None:
async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
for obj in objs:
if cls is None:
obj.provider_id = provider_id
@ -131,9 +129,7 @@ class CommonRoutingTableImpl(RoutingTable):
for p in self.impls_by_provider_id.values():
await p.shutdown()
def get_provider_impl(
self, routing_key: str, provider_id: Optional[str] = None
) -> Any:
def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
def apiname_object():
if isinstance(self, ModelsRoutingTable):
return ("Inference", "model")
@ -171,9 +167,7 @@ class CommonRoutingTableImpl(RoutingTable):
raise ValueError(f"Provider not found for `{routing_key}`")
async def get_object_by_identifier(
self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
# Get from disk registry
obj = await self.dist_registry.get(type, identifier)
if not obj:
@ -183,13 +177,9 @@ class CommonRoutingTableImpl(RoutingTable):
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
await self.dist_registry.delete(obj.type, obj.identifier)
await unregister_object_from_provider(
obj, self.impls_by_provider_id[obj.provider_id]
)
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
async def register_object(
self, obj: RoutableObjectWithProvider
) -> RoutableObjectWithProvider:
async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
# if provider_id is not specified, pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
@ -244,9 +234,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
if model_type is None:
model_type = ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError(
"Embedding model must have an embedding dimension in its metadata"
)
raise ValueError("Embedding model must have an embedding dimension in its metadata")
model = Model(
identifier=model_id,
provider_resource_id=provider_model_id,
@ -266,9 +254,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(
data=await self.get_all_with_type(ResourceType.shield.value)
)
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
async def get_shield(self, identifier: str) -> Optional[Shield]:
return await self.get_object_by_identifier("shield", identifier)
@ -340,9 +326,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
if model.model_type != ModelType.embedding:
raise ValueError(f"Model {embedding_model} is not an embedding model")
if "embedding_dimension" not in model.metadata:
raise ValueError(
f"Model {embedding_model} does not have an embedding dimension"
)
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
vector_db_data = {
"identifier": vector_db_id,
"type": ResourceType.vector_db.value,
@ -364,9 +348,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(
data=await self.get_all_with_type(ResourceType.dataset.value)
)
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
return await self.get_object_by_identifier("dataset", dataset_id)
@ -411,9 +393,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(
data=await self.get_all_with_type(ResourceType.scoring_function.value)
)
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
@ -510,12 +490,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
args: Optional[Dict[str, Any]] = None,
) -> None:
tools = []
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
toolgroup_id, mcp_endpoint
)
tool_host = (
ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
)
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
for tool_def in tool_defs:
tools.append(

@ -43,9 +43,7 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
if api == Api.tool_runtime:
for tool_group in SpecialToolGroup:
sub_protocol = toolgroup_protocols[tool_group]
sub_protocol_methods = inspect.getmembers(
sub_protocol, predicate=inspect.isfunction
)
sub_protocol_methods = inspect.getmembers(sub_protocol, predicate=inspect.isfunction)
for name, method in sub_protocol_methods:
if not hasattr(method, "__webmethod__"):
continue

@ -76,9 +76,7 @@ async def global_exception_handler(request: Request, exc: Exception):
traceback.print_exception(exc)
http_exc = translate_exception(exc)
return JSONResponse(
status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}}
)
return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}})
def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]:
@ -178,9 +176,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:
if is_streaming:
return StreamingResponse(
sse_generator(func(**kwargs)), media_type="text/event-stream"
)
return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
else:
value = func(**kwargs)
return await maybe_await(value)
@ -190,11 +186,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
sig = inspect.signature(func)
new_params = [
inspect.Parameter(
"request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request
)
]
new_params = [inspect.Parameter("request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request)]
new_params.extend(sig.parameters.values())
path_params = extract_path_params(route)
@ -202,15 +194,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
# Annotate parameters that are in the path with Path(...) and others with Body(...)
new_params = [new_params[0]] + [
(
param.replace(
annotation=Annotated[
param.annotation, FastapiPath(..., title=param.name)
]
)
param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
if param.name in path_params
else param.replace(
annotation=Annotated[param.annotation, Body(..., embed=True)]
)
else param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
)
for param in new_params[1:]
]
@ -244,12 +230,8 @@ class ClientVersionMiddleware:
client_version = headers.get(b"x-llamastack-client-version", b"").decode()
if client_version:
try:
client_version_parts = tuple(
map(int, client_version.split(".")[:2])
)
server_version_parts = tuple(
map(int, self.server_version.split(".")[:2])
)
client_version_parts = tuple(map(int, client_version.split(".")[:2]))
server_version_parts = tuple(map(int, self.server_version.split(".")[:2]))
if client_version_parts != server_version_parts:
async def send_version_error(send):
@ -267,9 +249,7 @@ class ClientVersionMiddleware:
}
}
).encode()
await send(
{"type": "http.response.body", "body": error_msg}
)
await send({"type": "http.response.body", "body": error_msg})
return await send_version_error(send)
except (ValueError, IndexError):
@ -296,9 +276,7 @@ def main():
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
help="Port to listen on",
)
parser.add_argument(
"--disable-ipv6", action="store_true", help="Whether to disable IPv6 support"
)
parser.add_argument("--disable-ipv6", action="store_true", help="Whether to disable IPv6 support")
parser.add_argument(
"--env",
action="append",
@ -323,9 +301,7 @@ def main():
raise ValueError(f"Config file {config_file} does not exist")
print(f"Using config file: {config_file}")
elif args.template:
config_file = (
Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
)
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
if not config_file.exists():
raise ValueError(f"Template {args.template} does not exist")
print(f"Using template {args.template} config file: {config_file}")
@ -383,9 +359,7 @@ def main():
impl_method = getattr(impl, endpoint.name)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", category=UserWarning, module="pydantic._internal._fields"
)
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(
impl_method,
@ -416,9 +390,7 @@ def main():
def extract_path_params(route: str) -> List[str]:
segments = route.split("/")
params = [
seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")
]
params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]
return params

@ -110,9 +110,7 @@ class EnvVarError(Exception):
def __init__(self, var_name: str, path: str = ""):
self.var_name = var_name
self.path = path
super().__init__(
f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}"
)
super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}")
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
@ -187,9 +185,7 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
if not key:
raise ValueError(f"Empty key in environment variable pair: {env_pair}")
if not all(c.isalnum() or c == "_" for c in key):
raise ValueError(
f"Key must contain only alphanumeric characters and underscores: {key}"
)
raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}")
return key, value
except ValueError as e:
raise ValueError(
@ -202,20 +198,14 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
async def construct_stack(
run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
) -> Dict[Api, Any]:
dist_registry, _ = await create_dist_registry(
run_config.metadata_store, run_config.image_name
)
impls = await resolve_impls(
run_config, provider_registry or get_provider_registry(), dist_registry
)
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(), dist_registry)
await register_resources(run_config, impls)
return impls
def get_stack_run_config_from_template(template: str) -> StackRunConfig:
template_path = (
importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
)
template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
with importlib.resources.as_file(template_path) as path:
if not path.exists():

@ -25,9 +25,7 @@ class DistributionRegistry(Protocol):
def get_cached(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ...
async def update(
self, obj: RoutableObjectWithProvider
) -> RoutableObjectWithProvider: ...
async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: ...
async def register(self, obj: RoutableObjectWithProvider) -> bool: ...
@ -61,9 +59,7 @@ class DiskDistributionRegistry(DistributionRegistry):
async def initialize(self) -> None:
pass
def get_cached(
self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
def get_cached(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
# Disk registry does not have a cache
raise NotImplementedError("Disk registry does not have a cache")
@ -72,12 +68,8 @@ class DiskDistributionRegistry(DistributionRegistry):
values = await self.kvstore.range(start_key, end_key)
return _parse_registry_values(values)
async def get(
self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
json_str = await self.kvstore.get(
KEY_FORMAT.format(type=type, identifier=identifier)
)
async def get(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
json_str = await self.kvstore.get(KEY_FORMAT.format(type=type, identifier=identifier))
if not json_str:
return None
@ -143,9 +135,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
async def initialize(self) -> None:
await self._ensure_initialized()
def get_cached(
self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
def get_cached(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
return self.cache.get((type, identifier), None)
async def get_all(self) -> List[RoutableObjectWithProvider]:
@ -153,9 +143,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
async with self._locked_cache() as cache:
return list(cache.values())
async def get(
self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
async def get(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
await self._ensure_initialized()
cache_key = (type, identifier)
@ -197,9 +185,7 @@ async def create_dist_registry(
dist_kvstore = await kvstore_impl(metadata_store)
else:
dist_kvstore = await kvstore_impl(
SqliteKVStoreConfig(
db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix()
)
SqliteKVStoreConfig(db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix())
)
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
await dist_registry.initialize()

@ -161,9 +161,7 @@ async def test_duplicate_provider_registration(config):
result = await cached_registry.get("vector_db", "test_vector_db_2")
assert result is not None
assert (
result.embedding_model == original_vector_db.embedding_model
) # Original values preserved
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
@pytest.mark.asyncio
@ -193,14 +191,9 @@ async def test_get_all_objects(config):
# Verify each vector_db was stored correctly
for original_vector_db in test_vector_dbs:
matching_vector_dbs = [
v for v in all_results if v.identifier == original_vector_db.identifier
]
matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier]
assert len(matching_vector_dbs) == 1
stored_vector_db = matching_vector_dbs[0]
assert stored_vector_db.embedding_model == original_vector_db.embedding_model
assert stored_vector_db.provider_id == original_vector_db.provider_id
assert (
stored_vector_db.embedding_dimension
== original_vector_db.embedding_dimension
)
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension

@ -22,15 +22,11 @@ def main():
)
# Playground pages
chat_page = st.Page(
"page/playground/chat.py", title="Chat", icon="💬", default=True
)
chat_page = st.Page("page/playground/chat.py", title="Chat", icon="💬", default=True)
rag_page = st.Page("page/playground/rag.py", title="RAG", icon="💬", default=False)
# Distribution pages
resources_page = st.Page(
"page/distribution/resources.py", title="Resources", icon="🔍", default=False
)
resources_page = st.Page("page/distribution/resources.py", title="Resources", icon="🔍", default=False)
provider_page = st.Page(
"page/distribution/providers.py",
title="API Providers",

@ -23,15 +23,11 @@ class LlamaStackApi:
},
)
def run_scoring(
self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]
):
def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]):
"""Run scoring on a single row"""
if not scoring_params:
scoring_params = {fn_id: None for fn_id in scoring_function_ids}
return self.client.scoring.score(
input_rows=[row], scoring_functions=scoring_params
)
return self.client.scoring.score(input_rows=[row], scoring_functions=scoring_params)
llama_stack_api = LlamaStackApi()

@ -11,9 +11,7 @@ from modules.api import llama_stack_api
def datasets():
st.header("Datasets")
datasets_info = {
d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list()
}
datasets_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list()}
if len(datasets_info) > 0:
selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys()))
st.json(datasets_info[selected_dataset], expanded=True)

@ -12,12 +12,8 @@ def eval_tasks():
# Eval Tasks Section
st.header("Eval Tasks")
eval_tasks_info = {
d.identifier: d.to_dict() for d in llama_stack_api.client.eval_tasks.list()
}
eval_tasks_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.eval_tasks.list()}
if len(eval_tasks_info) > 0:
selected_eval_task = st.selectbox(
"Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect"
)
selected_eval_task = st.selectbox("Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect")
st.json(eval_tasks_info[selected_eval_task], expanded=True)

@ -11,9 +11,7 @@ from modules.api import llama_stack_api
def models():
# Models Section
st.header("Models")
models_info = {
m.identifier: m.to_dict() for m in llama_stack_api.client.models.list()
}
models_info = {m.identifier: m.to_dict() for m in llama_stack_api.client.models.list()}
selected_model = st.selectbox("Select a model", list(models_info.keys()))
st.json(models_info[selected_model])

@ -11,12 +11,7 @@ from modules.api import llama_stack_api
def scoring_functions():
st.header("Scoring Functions")
scoring_functions_info = {
s.identifier: s.to_dict()
for s in llama_stack_api.client.scoring_functions.list()
}
scoring_functions_info = {s.identifier: s.to_dict() for s in llama_stack_api.client.scoring_functions.list()}
selected_scoring_function = st.selectbox(
"Select a scoring function", list(scoring_functions_info.keys())
)
selected_scoring_function = st.selectbox("Select a scoring function", list(scoring_functions_info.keys()))
st.json(scoring_functions_info[selected_scoring_function], expanded=True)

@ -12,9 +12,7 @@ def shields():
# Shields Section
st.header("Shields")
shields_info = {
s.identifier: s.to_dict() for s in llama_stack_api.client.shields.list()
}
shields_info = {s.identifier: s.to_dict() for s in llama_stack_api.client.shields.list()}
selected_shield = st.selectbox("Select a shield", list(shields_info.keys()))
st.json(shields_info[selected_shield])

@ -10,14 +10,10 @@ from modules.api import llama_stack_api
def vector_dbs():
st.header("Vector Databases")
vector_dbs_info = {
v.identifier: v.to_dict() for v in llama_stack_api.client.vector_dbs.list()
}
vector_dbs_info = {v.identifier: v.to_dict() for v in llama_stack_api.client.vector_dbs.list()}
if len(vector_dbs_info) > 0:
selected_vector_db = st.selectbox(
"Select a vector database", list(vector_dbs_info.keys())
)
selected_vector_db = st.selectbox("Select a vector database", list(vector_dbs_info.keys()))
st.json(vector_dbs_info[selected_vector_db])
else:
st.info("No vector databases found")

@ -14,7 +14,6 @@ from modules.utils import process_dataset
def application_evaluation_page():
st.set_page_config(page_title="Evaluations (Scoring)", page_icon="🦙")
st.title("📊 Evaluations (Scoring)")
@ -83,9 +82,7 @@ def application_evaluation_page():
try:
new_params[param_name] = json.loads(value)
except json.JSONDecodeError:
st.error(
f"Invalid JSON for **{param_name}** in {scoring_fn_id}"
)
st.error(f"Invalid JSON for **{param_name}** in {scoring_fn_id}")
st.json(new_params)
scoring_params[scoring_fn_id] = new_params
@ -128,9 +125,7 @@ def application_evaluation_page():
output_res[fn_id].append(score_res.results[fn_id].score_rows[0])
# Display current row results using separate containers
progress_text_container.write(
f"Expand to see current processed result ({i + 1} / {len(rows)})"
)
progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
results_container.json(
score_res.to_json(),
expanded=2,

@ -195,7 +195,6 @@ def run_evaluation_3():
# Add run button and handle evaluation
if st.button("Run Evaluation"):
progress_text = "Running evaluation..."
progress_bar = st.progress(0, text=progress_text)
rows = rows.rows
@ -233,9 +232,7 @@ def run_evaluation_3():
output_res[scoring_fn] = []
output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])
progress_text_container.write(
f"Expand to see current processed result ({i + 1} / {len(rows)})"
)
progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
results_container.json(eval_res, expanded=2)
progress_bar.progress(1.0, text="Evaluation complete!")
@ -247,7 +244,6 @@ def run_evaluation_3():
def native_evaluation_page():
st.set_page_config(page_title="Evaluations (Generation + Scoring)", page_icon="🦙")
st.title("📊 Evaluations (Generation + Scoring)")

@ -11,9 +11,7 @@ from modules.api import llama_stack_api
with st.sidebar:
st.header("Configuration")
available_models = llama_stack_api.client.models.list()
available_models = [
model.identifier for model in available_models if model.model_type == "llm"
]
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
selected_model = st.selectbox(
"Choose a model",
available_models,
@ -128,6 +126,4 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
full_response = response
message_placeholder.markdown(full_response.completion_message.content)
st.session_state.messages.append(
{"role": "assistant", "content": full_response}
)
st.session_state.messages.append({"role": "assistant", "content": full_response})

@ -74,9 +74,7 @@ def rag_chat_page():
)
available_models = llama_stack_api.client.models.list()
available_models = [
model.identifier for model in available_models if model.model_type == "llm"
]
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
selected_model = st.selectbox(
"Choose a model",
available_models,
@ -137,9 +135,7 @@ def rag_chat_page():
dict(
name="builtin::rag",
args={
"vector_db_ids": [
vector_db_id for vector_db_id in selected_vector_dbs
],
"vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
},
)
],
@ -186,9 +182,7 @@ def rag_chat_page():
message_placeholder.markdown(full_response + "▌")
message_placeholder.markdown(full_response)
st.session_state.messages.append(
{"role": "assistant", "content": full_response}
)
st.session_state.messages.append({"role": "assistant", "content": full_response})
rag_chat_page()

@ -8,9 +8,7 @@ import os
from pathlib import Path
LLAMA_STACK_CONFIG_DIR = Path(
os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/"))
)
LLAMA_STACK_CONFIG_DIR = Path(os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/")))
DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"

@ -31,15 +31,11 @@ def is_list_of_primitives(field_type):
def is_basemodel_without_fields(typ):
return (
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0
)
return inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0
def can_recurse(typ):
return (
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
)
return inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
def get_literal_values(field):
@ -72,7 +68,7 @@ def is_discriminated_union(typ) -> bool:
if isinstance(typ, FieldInfo):
return typ.discriminator
else:
if not (get_origin(typ) is Annotated):
if get_origin(typ) is not Annotated:
return False
args = get_args(typ)
return len(args) >= 2 and args[1].discriminator
@ -116,9 +112,7 @@ def prompt_for_discriminated_union(
chosen_type = type_map[discriminator_value]
log.info(f"\nConfiguring {chosen_type.__name__}:")
if existing_value and (
getattr(existing_value, discriminator) != discriminator_value
):
if existing_value and (getattr(existing_value, discriminator) != discriminator_value):
existing_value = None
sub_config = prompt_for_config(chosen_type, existing_value)
@ -134,9 +128,7 @@ def prompt_for_discriminated_union(
#
# doesn't support List[nested_class] yet or Dicts of any kind. needs a bunch of
# unit tests for coverage.
def prompt_for_config(
config_type: type[BaseModel], existing_config: Optional[BaseModel] = None
) -> BaseModel:
def prompt_for_config(config_type: type[BaseModel], existing_config: Optional[BaseModel] = None) -> BaseModel:
"""
Recursively prompt the user for configuration values based on a Pydantic BaseModel.
@ -150,17 +142,11 @@ def prompt_for_config(
for field_name, field in config_type.__fields__.items():
field_type = field.annotation
existing_value = (
getattr(existing_config, field_name) if existing_config else None
)
existing_value = getattr(existing_config, field_name) if existing_config else None
if existing_value:
default_value = existing_value
else:
default_value = (
field.default
if not isinstance(field.default, PydanticUndefinedType)
else None
)
default_value = field.default if not isinstance(field.default, PydanticUndefinedType) else None
is_required = field.is_required
# Skip fields with Literal type
@ -183,15 +169,11 @@ def prompt_for_config(
config_data[field_name] = validated_value
break
except KeyError:
log.error(
f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}"
)
log.error(f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}")
continue
if is_discriminated_union(field):
config_data[field_name] = prompt_for_discriminated_union(
field_name, field, existing_value
)
config_data[field_name] = prompt_for_discriminated_union(field_name, field, existing_value)
continue
if is_optional(field_type) and can_recurse(get_non_none_type(field_type)):
@ -202,9 +184,7 @@ def prompt_for_config(
nested_type = get_non_none_type(field_type)
log.info(f"Entering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(nested_type, existing_value)
elif is_optional(field_type) and is_discriminated_union(
get_non_none_type(field_type)
):
elif is_optional(field_type) and is_discriminated_union(get_non_none_type(field_type)):
prompt = f"Do you want to configure {field_name}? (y/n): "
if input(prompt).lower() == "n":
config_data[field_name] = None
@ -260,16 +240,12 @@ def prompt_for_config(
try:
value = json.loads(user_input)
if not isinstance(value, list):
raise ValueError(
"Input must be a JSON-encoded list"
)
raise ValueError("Input must be a JSON-encoded list")
element_type = get_args(field_type)[0]
value = [element_type(item) for item in value]
except json.JSONDecodeError:
log.error(
'Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]'
)
log.error('Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]')
continue
except ValueError as e:
log.error(f"{str(e)}")
@ -279,20 +255,14 @@ def prompt_for_config(
try:
value = json.loads(user_input)
if not isinstance(value, dict):
raise ValueError(
"Input must be a JSON-encoded dictionary"
)
raise ValueError("Input must be a JSON-encoded dictionary")
except json.JSONDecodeError:
log.error(
"Invalid JSON. Please enter a valid JSON-encoded dict."
)
log.error("Invalid JSON. Please enter a valid JSON-encoded dict.")
continue
# Convert the input to the correct type
elif inspect.isclass(field_type) and issubclass(
field_type, BaseModel
):
elif inspect.isclass(field_type) and issubclass(field_type, BaseModel):
# For nested BaseModels, we assume a dictionary-like string input
import ast
@ -301,16 +271,12 @@ def prompt_for_config(
value = field_type(user_input)
except ValueError:
log.error(
f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}"
)
log.error(f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}")
continue
try:
# Validate the field using our manual validation function
validated_value = manually_validate_field(
config_type, field_name, value
)
validated_value = manually_validate_field(config_type, field_name, value)
config_data[field_name] = validated_value
break
except ValueError as e:

@ -11,9 +11,7 @@ from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import MetaReferenceAgentsImplConfig
async def get_provider_impl(
config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]
):
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]):
from .agents import MetaReferenceAgentsImpl
impl = MetaReferenceAgentsImpl(

@ -74,9 +74,7 @@ log = logging.getLogger(__name__)
def make_random_string(length: int = 8):
return "".join(
secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
)
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
@ -153,9 +151,7 @@ class ChatAgent(ShieldRunnerMixin):
async def create_session(self, name: str) -> str:
return await self.storage.create_session(name)
async def create_and_execute_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
@ -206,14 +202,9 @@ class ChatAgent(ShieldRunnerMixin):
output_message = chunk
continue
assert isinstance(
chunk, AgentTurnResponseStreamChunk
), f"Unexpected type {type(chunk)}"
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if (
event.payload.event_type
== AgentTurnResponseEventType.step_complete.value
):
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details)
yield chunk
@ -388,9 +379,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
if documents:
await self.handle_documents(
session_id, documents, input_messages, tool_defs
)
await self.handle_documents(session_id, documents, input_messages, tool_defs)
if RAG_TOOL_GROUP in toolgroups and len(input_messages) > 0:
with tracing.span(MEMORY_QUERY_TOOL) as span:
@ -408,9 +397,7 @@ class ChatAgent(ShieldRunnerMixin):
vector_db_ids = args.get("vector_db_ids", [])
query_config = args.get("query_config")
if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(
query_config
)
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
else:
# handle someone passing an empty dict
query_config = RAGQueryConfig()
@ -438,9 +425,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
result = await self.tool_runtime_api.rag_tool.query(
content=concat_interleaved_content(
[msg.content for msg in input_messages]
),
content=concat_interleaved_content([msg.content for msg in input_messages]),
vector_db_ids=vector_db_ids,
query_config=query_config,
)
@ -472,9 +457,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
span.set_attribute("input", [m.model_dump_json() for m in input_messages])
span.set_attribute("output", retrieved_context)
span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
@ -511,9 +494,7 @@ class ChatAgent(ShieldRunnerMixin):
self.agent_config.model,
input_messages,
tools=[
tool
for tool in tool_defs.values()
if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
],
tool_prompt_format=self.agent_config.tool_prompt_format,
response_format=self.agent_config.response_format,
@ -560,12 +541,8 @@ class ChatAgent(ShieldRunnerMixin):
if event.stop_reason is not None:
stop_reason = event.stop_reason
span.set_attribute("stop_reason", stop_reason)
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
span.set_attribute(
"output", f"content: {content} tool_calls: {tool_calls}"
)
span.set_attribute("input", [m.model_dump_json() for m in input_messages])
span.set_attribute("output", f"content: {content} tool_calls: {tool_calls}")
stop_reason = stop_reason or StopReason.out_of_tokens
@ -667,9 +644,7 @@ class ChatAgent(ShieldRunnerMixin):
toolgroup_args,
tool_to_group,
)
assert (
len(result_messages) == 1
), "Currently not supporting multiple messages"
assert len(result_messages) == 1, "Currently not supporting multiple messages"
result_message = result_messages[0]
span.set_attribute("output", result_message.model_dump_json())
@ -697,9 +672,7 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
if out_attachment := _interpret_content_as_attachment(
result_message.content
):
if out_attachment := _interpret_content_as_attachment(result_message.content):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to final message
@ -714,22 +687,14 @@ class ChatAgent(ShieldRunnerMixin):
) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]:
# Determine which tools to include
agent_config_toolgroups = set(
(
toolgroup.name
if isinstance(toolgroup, AgentToolGroupWithArgs)
else toolgroup
)
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
for toolgroup in self.agent_config.toolgroups
)
toolgroups_for_turn_set = (
agent_config_toolgroups
if toolgroups_for_turn is None
else {
(
toolgroup.name
if isinstance(toolgroup, AgentToolGroupWithArgs)
else toolgroup
)
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
for toolgroup in toolgroups_for_turn
}
)
@ -759,10 +724,7 @@ class ChatAgent(ShieldRunnerMixin):
continue
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
for tool_def in tools.data:
if (
toolgroup_name.startswith("builtin")
and toolgroup_name != RAG_TOOL_GROUP
):
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
tool_name = tool_def.identifier
built_in_type = BuiltinTool.brave_search
if tool_name == "web_search":
@ -773,9 +735,7 @@ class ChatAgent(ShieldRunnerMixin):
if tool_def_map.get(built_in_type, None):
raise ValueError(f"Tool {built_in_type} already exists")
tool_def_map[built_in_type] = ToolDefinition(
tool_name=built_in_type
)
tool_def_map[built_in_type] = ToolDefinition(tool_name=built_in_type)
tool_to_group[built_in_type] = tool_def.toolgroup_id
continue
@ -821,9 +781,7 @@ class ChatAgent(ShieldRunnerMixin):
# Save the contents to a tempdir and use its path as a URL if code interpreter is present
if code_interpreter_tool:
for c in content_items:
temp_file_path = os.path.join(
self.tempdir, f"{make_random_string()}.txt"
)
temp_file_path = os.path.join(self.tempdir, f"{make_random_string()}.txt")
with open(temp_file_path, "w") as temp_file:
temp_file.write(c.content)
url_items.append(URL(uri=f"file://{temp_file_path}"))
@ -849,8 +807,7 @@ class ChatAgent(ShieldRunnerMixin):
# we try to load the data from the URLs and content items as a message to inference
# and add it to the last message's context
input_messages[-1].context = "\n".join(
[doc.content for doc in content_items]
+ await load_data_from_urls(url_items)
[doc.content for doc in content_items] + await load_data_from_urls(url_items)
)
async def _ensure_vector_db(self, session_id: str) -> str:
@ -874,9 +831,7 @@ class ChatAgent(ShieldRunnerMixin):
return vector_db_id
async def add_to_session_vector_db(
self, session_id: str, data: List[Document]
) -> None:
async def add_to_session_vector_db(self, session_id: str, data: List[Document]) -> None:
vector_db_id = await self._ensure_vector_db(session_id)
documents = [
RAGDocument(
@ -931,11 +886,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
else:
raise ValueError(f"Unsupported URL {url}")
content.append(
TextContentItem(
text=f'# There is a file accessible to you at "{filepath}"\n'
)
)
content.append(TextContentItem(text=f'# There is a file accessible to you at "{filepath}"\n'))
return ToolResponseMessage(
call_id="",

@ -94,16 +94,12 @@ class MetaReferenceAgentsImpl(Agents):
try:
agent_config = json.loads(agent_config)
except json.JSONDecodeError as e:
raise ValueError(
f"Could not JSON decode agent config for {agent_id}"
) from e
raise ValueError(f"Could not JSON decode agent config for {agent_id}") from e
try:
agent_config = AgentConfig(**agent_config)
except Exception as e:
raise ValueError(
f"Could not validate(?) agent config for {agent_id}"
) from e
raise ValueError(f"Could not validate(?) agent config for {agent_id}") from e
return ChatAgent(
agent_id=agent_id,
@ -115,9 +111,7 @@ class MetaReferenceAgentsImpl(Agents):
tool_runtime_api=self.tool_runtime_api,
tool_groups_api=self.tool_groups_api,
persistence_store=(
self.persistence_store
if agent_config.enable_session_persistence
else self.in_memory_store
self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store
),
)
@ -168,22 +162,14 @@ class MetaReferenceAgentsImpl(Agents):
async for event in agent.create_and_execute_turn(request):
yield event
async def get_agents_turn(
self, agent_id: str, session_id: str, turn_id: str
) -> Turn:
turn = await self.persistence_store.get(
f"session:{agent_id}:{session_id}:{turn_id}"
)
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
turn = json.loads(turn)
turn = Turn(**turn)
return turn
async def get_agents_step(
self, agent_id: str, session_id: str, turn_id: str, step_id: str
) -> AgentStepResponse:
turn = await self.persistence_store.get(
f"session:{agent_id}:{session_id}:{turn_id}"
)
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
turn = json.loads(turn)
turn = Turn(**turn)
steps = turn.steps
@ -203,9 +189,7 @@ class MetaReferenceAgentsImpl(Agents):
turns = []
if turn_ids:
for turn_id in turn_ids:
turn = await self.persistence_store.get(
f"session:{agent_id}:{session_id}:{turn_id}"
)
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
turn = json.loads(turn)
turn = Turn(**turn)
turns.append(turn)

@ -33,9 +33,7 @@ class ShieldRunnerMixin:
self.input_shields = input_shields
self.output_shields = output_shields
async def run_multiple_shields(
self, messages: List[Message], identifiers: List[str]
) -> None:
async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
responses = await asyncio.gather(
*[
self.safety_api.run_shield(

@ -64,9 +64,7 @@ class MockInferenceAPI:
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
async def stream_response():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
@ -104,9 +102,7 @@ class MockInferenceAPI:
class MockSafetyAPI:
async def run_shield(
self, shield_id: str, messages: List[Message]
) -> RunShieldResponse:
async def run_shield(self, shield_id: str, messages: List[Message]) -> RunShieldResponse:
return RunShieldResponse(violation=None)
@ -129,9 +125,7 @@ class MockVectorIOAPI:
class MockToolGroupsAPI:
async def register_tool_group(
self, toolgroup_id: str, provider_id: str, mcp_endpoint=None, args=None
) -> None:
async def register_tool_group(self, toolgroup_id: str, provider_id: str, mcp_endpoint=None, args=None) -> None:
pass
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
@ -341,26 +335,21 @@ async def test_chat_agent_complex_turn(get_chat_agent):
assert len(responses) > 0
step_types = [
response.event.payload.step_type
for response in responses
if hasattr(response.event.payload, "step_type")
response.event.payload.step_type for response in responses if hasattr(response.event.payload, "step_type")
]
assert StepType.shield_call in step_types, "Shield call step is missing"
assert StepType.inference in step_types, "Inference step is missing"
event_types = [
response.event.payload.event_type
for response in responses
if hasattr(response.event.payload, "event_type")
response.event.payload.event_type for response in responses if hasattr(response.event.payload, "event_type")
]
assert "turn_start" in event_types, "Start event is missing"
assert "turn_complete" in event_types, "Complete event is missing"
assert any(
isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
for response in responses
), "Turn complete event is missing"
assert any(isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload) for response in responses), (
"Turn complete event is missing"
)
turn_complete_payload = next(
response.event.payload
for response in responses
@ -380,9 +369,7 @@ async def test_chat_agent_complex_turn(get_chat_agent):
([MEMORY_TOOLGROUP, CODE_INTERPRETER_TOOLGROUP], True, True), # all tools
],
)
async def test_chat_agent_tools(
get_agents_impl, toolgroups, expected_memory, expected_code_interpreter
):
async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, expected_code_interpreter):
impl = await get_agents_impl
agent_config = AgentConfig(
model="test_model",

@ -172,9 +172,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
new_rows_df = pandas.DataFrame(rows)
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
dataset_impl.df = pandas.concat(
[dataset_impl.df, new_rows_df], ignore_index=True
)
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
url = str(dataset_info.dataset_def.url)
parsed_url = urlparse(url)
@ -189,12 +187,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
raise ValueError("Data URL must be a base64-encoded CSV")
csv_buffer = dataset_impl.df.to_csv(index=False)
base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode(
"utf-8"
)
dataset_info.dataset_def.url = URL(
uri=f"data:text/csv;base64,{base64_content}"
)
base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode("utf-8")
dataset_info.dataset_def.url = URL(uri=f"data:text/csv;base64,{base64_content}")
else:
raise ValueError(
f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."

@ -91,14 +91,10 @@ class MetaReferenceEvalImpl(
candidate = task_config.eval_candidate
scoring_functions = task_def.scoring_functions
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)
)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=(
-1 if task_config.num_examples is None else task_config.num_examples
),
rows_in_page=(-1 if task_config.num_examples is None else task_config.num_examples),
)
res = await self.evaluate_rows(
task_id=task_id,
@ -127,9 +123,7 @@ class MetaReferenceEvalImpl(
input_messages = [UserMessage(**x) for x in input_messages]
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
session_create_response = await self.agents_api.create_agent_session(
agent_id, f"session-{i}"
)
session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
session_id = session_create_response.session_id
turn_request = dict(
@ -138,12 +132,7 @@ class MetaReferenceEvalImpl(
messages=input_messages,
stream=True,
)
turn_response = [
chunk
async for chunk in await self.agents_api.create_agent_turn(
**turn_request
)
]
turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)]
final_event = turn_response[-1].event.payload
# check if there's a memory retrieval step and extract the context
@ -152,14 +141,10 @@ class MetaReferenceEvalImpl(
if step.step_type == StepType.tool_execution.value:
for tool_response in step.tool_responses:
if tool_response.tool_name == MEMORY_QUERY_TOOL:
memory_rag_context = " ".join(
x.text for x in tool_response.content
)
memory_rag_context = " ".join(x.text for x in tool_response.content)
agent_generation = {}
agent_generation[ColumnName.generated_answer.value] = (
final_event.turn.output_message.content
)
agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content
if memory_rag_context:
agent_generation[ColumnName.context.value] = memory_rag_context
@ -171,9 +156,7 @@ class MetaReferenceEvalImpl(
self, input_rows: List[Dict[str, Any]], task_config: EvalTaskConfig
) -> List[Dict[str, Any]]:
candidate = task_config.eval_candidate
assert (
candidate.sampling_params.max_tokens is not None
), "SamplingParams.max_tokens must be provided"
assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
generations = []
for x in tqdm(input_rows):
@ -184,15 +167,9 @@ class MetaReferenceEvalImpl(
content=input_content,
sampling_params=candidate.sampling_params,
)
generations.append(
{
ColumnName.generated_answer.value: response.completion_message.content
}
)
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
elif ColumnName.chat_completion_input.value in x:
chat_completion_input_str = str(
x[ColumnName.chat_completion_input.value]
)
chat_completion_input_str = str(x[ColumnName.chat_completion_input.value])
input_messages = eval(chat_completion_input_str)
input_messages = [UserMessage(**x) for x in input_messages]
messages = []
@ -204,11 +181,7 @@ class MetaReferenceEvalImpl(
messages=messages,
sampling_params=candidate.sampling_params,
)
generations.append(
{
ColumnName.generated_answer.value: response.completion_message.content
}
)
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
else:
raise ValueError("Invalid input row")
@ -230,10 +203,7 @@ class MetaReferenceEvalImpl(
raise ValueError(f"Invalid candidate type: {candidate.type}")
# scoring with generated_answer
score_input_rows = [
input_r | generated_r
for input_r, generated_r in zip(input_rows, generations)
]
score_input_rows = [input_r | generated_r for input_r, generated_r in zip(input_rows, generations)]
if task_config.type == "app" and task_config.scoring_params is not None:
scoring_functions_dict = {
@ -241,9 +211,7 @@ class MetaReferenceEvalImpl(
for scoring_fn_id in scoring_functions
}
else:
scoring_functions_dict = {
scoring_fn_id: None for scoring_fn_id in scoring_functions
}
scoring_functions_dict = {scoring_fn_id: None for scoring_fn_id in scoring_functions}
score_response = await self.scoring_api.score(
input_rows=score_input_rows, scoring_functions=scoring_functions_dict

@ -40,9 +40,7 @@ class MetaReferenceInferenceConfig(BaseModel):
repos = [m.huggingface_repo for m in permitted_models]
if model not in (descriptors + repos):
model_list = "\n\t".join(repos)
raise ValueError(
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
)
raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
return model
@classmethod

@ -83,9 +83,7 @@ class TokenResult(BaseModel):
class Llama:
@staticmethod
def build(
config: Union[
MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
],
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
model_id: str,
llama_model: Model,
):
@ -150,9 +148,9 @@ class Llama:
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
assert model_parallel_size == len(checkpoints), (
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
)
ckpt_path = checkpoints[get_model_parallel_rank()]
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
@ -168,9 +166,9 @@ class Llama:
)
tokenizer = Tokenizer.get_instance()
assert (
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
assert model_args.vocab_size == tokenizer.n_words, (
f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
)
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
if isinstance(config.quantization, Fp8QuantizationConfig):
@ -193,10 +191,7 @@ class Llama:
model = convert_to_int4_quantized_model(model, model_args, config)
model.load_state_dict(state_dict, strict=True)
if (
model_args.quantization_args is not None
and model_args.quantization_args.spinquant
):
if model_args.quantization_args is not None and model_args.quantization_args.spinquant:
# Add a wrapper for adding hadamard transform for spinquant.
# This needs to be done after loading the state dict otherwise an error will be raised while
# loading the state dict.
@ -206,9 +201,7 @@ class Llama:
add_hadamard_transform_for_spinquant(model)
else:
raise NotImplementedError(
"Currently int4 and fp8 are the only supported quantization methods."
)
raise NotImplementedError("Currently int4 and fp8 are the only supported quantization methods.")
else:
if device == "cuda":
if torch.cuda.is_bf16_supported():
@ -262,10 +255,7 @@ class Llama:
params = self.model.params
if print_input_tokens:
input_tokens = [
self.formatter.vision_token if t == 128256 else t
for t in model_input.tokens
]
input_tokens = [self.formatter.vision_token if t == 128256 else t for t in model_input.tokens]
log.info("Input to model -> " + self.tokenizer.decode(input_tokens))
prompt_tokens = [model_input.tokens]
@ -287,12 +277,10 @@ class Llama:
mask = model_input.vision.mask if model_input.vision is not None else []
# the method works for bsz > 1 so add a batch dimension
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = (
self.model.compute_vision_tokens_masks(
batch_images=[images],
batch_masks=[mask],
total_len=total_len,
)
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
batch_images=[images],
batch_masks=[mask],
total_len=total_len,
)
pad_id = self.tokenizer.pad_id
@ -340,9 +328,7 @@ class Llama:
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
)
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
target = tokens[:, prev_pos + 1 : cur_pos + 1]
@ -365,17 +351,11 @@ class Llama:
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (
torch.isin(next_token, stop_tokens)
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
yield TokenResult(
token=next_token[0].item(),
text=self.tokenizer.decode(next_token.tolist()),
logprobs=(
token_logprobs[:, cur_pos : cur_pos + 1][0].tolist()
if logprobs
else None
),
logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
)
prev_pos = cur_pos
@ -388,11 +368,7 @@ class Llama:
) -> Generator:
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
if (
max_gen_len is None
or max_gen_len == 0
or max_gen_len >= self.model.params.max_seq_len
):
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
model_input = self.formatter.encode_content(request.content)
@ -417,11 +393,7 @@ class Llama:
) -> Generator:
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
if (
max_gen_len is None
or max_gen_len == 0
or max_gen_len >= self.model.params.max_seq_len
):
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
@ -473,9 +445,7 @@ class LogitsProcessor:
self.token_enforcer = token_enforcer
self.mask: Optional[torch.Tensor] = None
def process_logits(
self, tokens: torch.Tensor, scores: torch.Tensor
) -> torch.Tensor:
def process_logits(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
token_sequence = tokens[0, :].tolist()
allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
@ -510,9 +480,7 @@ def get_logits_processor(
return LogitsProcessor(token_enforcer)
def _build_regular_tokens_list(
tokenizer: Tokenizer, vocab_size: int
) -> List[Tuple[int, str, bool]]:
def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> List[Tuple[int, str, bool]]:
token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
regular_tokens = []

@ -80,9 +80,7 @@ class MetaReferenceInferenceImpl(
async def load_model(self, model_id, llama_model) -> None:
log.info(f"Loading model `{model_id}`")
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(
self.config, model_id, llama_model
)
self.generator = LlamaModelParallelGenerator(self.config, model_id, llama_model)
self.generator.start()
else:
self.generator = Llama.build(self.config, model_id, llama_model)
@ -100,9 +98,7 @@ class MetaReferenceInferenceImpl(
"No avaible model yet, please register your requested model or add your model in the resouces first"
)
elif request.model != self.model_id:
raise RuntimeError(
f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}"
)
raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
async def unregister_model(self, model_id: str) -> None:
pass
@ -184,13 +180,7 @@ class MetaReferenceInferenceImpl(
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs = [
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
]
logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]
yield CompletionResponseStreamChunk(
delta=text,
@ -212,9 +202,7 @@ class MetaReferenceInferenceImpl(
for x in impl():
yield x
async def _nonstream_completion(
self, request: CompletionRequest
) -> CompletionResponse:
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
def impl():
tokens = []
logprobs = []
@ -231,13 +219,7 @@ class MetaReferenceInferenceImpl(
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
@ -289,9 +271,7 @@ class MetaReferenceInferenceImpl(
self.check_model(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(
request, self.llama_model.core_model_id.value
)
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
@ -304,9 +284,7 @@ class MetaReferenceInferenceImpl(
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
def impl():
tokens = []
logprobs = []
@ -323,20 +301,12 @@ class MetaReferenceInferenceImpl(
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
raw_message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
raw_message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
@ -352,9 +322,7 @@ class MetaReferenceInferenceImpl(
else:
return impl()
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
def impl():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
@ -405,13 +373,7 @@ class MetaReferenceInferenceImpl(
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
@ -424,9 +386,7 @@ class MetaReferenceInferenceImpl(
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:

@ -91,9 +91,7 @@ class LlamaModelParallelGenerator:
self.group = ModelParallelProcessGroup(
model_parallel_size,
init_model_cb=partial(
init_model_cb, self.config, self.model_id, self.llama_model
),
init_model_cb=partial(init_model_cb, self.config, self.model_id, self.llama_model),
)
self.group.start()
return self

@ -55,47 +55,33 @@ class ProcessingMessageName(str, Enum):
class ReadyRequest(BaseModel):
type: Literal[ProcessingMessageName.ready_request] = (
ProcessingMessageName.ready_request
)
type: Literal[ProcessingMessageName.ready_request] = ProcessingMessageName.ready_request
class ReadyResponse(BaseModel):
type: Literal[ProcessingMessageName.ready_response] = (
ProcessingMessageName.ready_response
)
type: Literal[ProcessingMessageName.ready_response] = ProcessingMessageName.ready_response
class EndSentinel(BaseModel):
type: Literal[ProcessingMessageName.end_sentinel] = (
ProcessingMessageName.end_sentinel
)
type: Literal[ProcessingMessageName.end_sentinel] = ProcessingMessageName.end_sentinel
class CancelSentinel(BaseModel):
type: Literal[ProcessingMessageName.cancel_sentinel] = (
ProcessingMessageName.cancel_sentinel
)
type: Literal[ProcessingMessageName.cancel_sentinel] = ProcessingMessageName.cancel_sentinel
class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = (
ProcessingMessageName.task_request
)
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]
class TaskResponse(BaseModel):
type: Literal[ProcessingMessageName.task_response] = (
ProcessingMessageName.task_response
)
type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
result: TokenResult
class ExceptionResponse(BaseModel):
type: Literal[ProcessingMessageName.exception_response] = (
ProcessingMessageName.exception_response
)
type: Literal[ProcessingMessageName.exception_response] = ProcessingMessageName.exception_response
error: str
@ -189,9 +175,7 @@ def retrieve_requests(reply_socket_url: str):
group=get_model_parallel_group(),
)
if isinstance(updates[0], CancelSentinel):
log.info(
"quitting generation loop because request was cancelled"
)
log.info("quitting generation loop because request was cancelled")
break
if mp_rank_0():
@ -350,9 +334,7 @@ class ModelParallelProcessGroup:
def run_inference(
self,
req: Union[
CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent
],
req: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent],
) -> Generator:
assert not self.running, "inference already running"

@ -19,9 +19,7 @@ try:
log.info("Using efficient FP8 operators in FBGEMM.")
except ImportError:
log.error(
"No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt."
)
log.error("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
raise
import torch
@ -60,14 +58,8 @@ def ffn_swiglu(
num_tokens: Optional[Tensor] = None,
is_memory_bounded: bool = False,
) -> Tensor:
if (
isinstance(w1, Fp8ScaledWeights)
and isinstance(w3, Fp8ScaledWeights)
and isinstance(w2, Fp8ScaledWeights)
):
return ffn_swiglu_fp8_dynamic(
x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded
)
if isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights):
return ffn_swiglu_fp8_dynamic(x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded)
(B, T, D) = x.shape # noqa: N806
(HD_L, D_) = w1.shape # noqa: N806
@ -146,12 +138,8 @@ def fc_fp8_dynamic(
Single w8a8 fc layer with dynamic row-wise scaling.
"""
if isinstance(w, Fp8RowwiseWeights):
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
x, num_tokens, activation_scale_ub
)
y = torch.ops.fbgemm.f8f8bf16_rowwise(
xq, w.weight, x_scale, w.scale, use_fast_accum=True
)
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, activation_scale_ub)
y = torch.ops.fbgemm.f8f8bf16_rowwise(xq, w.weight, x_scale, w.scale, use_fast_accum=True)
del xq
return y

@ -17,8 +17,7 @@ from torch import Tensor
@unittest.skipIf(
not torch.cuda.is_available()
or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
not torch.cuda.is_available() or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
"Skip when H100 is not available",
)
class FP8Tests(unittest.TestCase):

@ -57,9 +57,7 @@ class HadamardModule(torch.nn.Module):
return x
def add_hadamard_transform_for_spinquant(
model: torch.nn.Module, prefix: str = ""
) -> None:
def add_hadamard_transform_for_spinquant(model: torch.nn.Module, prefix: str = "") -> None:
"""
Adds a Hadamard transform to the last linear layer of each feedforward network (FFN) in the model.
This function recursively traverses the model's children and looks for layers that match the pattern
@ -81,12 +79,8 @@ def add_hadamard_transform_for_spinquant(
for module_name, module in model.named_children():
child_full_name = prefix + "." + module_name
if re.search(pattern_last_linear_ffn, child_full_name):
new_module = nn.Sequential(
HadamardModule(group_size=module.in_features), module
)
new_module = nn.Sequential(HadamardModule(group_size=module.in_features), module)
del module
setattr(model, module_name, new_module)
else:
add_hadamard_transform_for_spinquant(
module, (prefix + "." if prefix else prefix) + module_name
)
add_hadamard_transform_for_spinquant(module, (prefix + "." if prefix else prefix) + module_name)

@ -63,12 +63,8 @@ def convert_to_fp8_quantized_model(
# Move weights to GPU with quantization
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
log.info("Loading fp8 scales...")
fp8_scales_path = os.path.join(
checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
)
assert os.path.isfile(
fp8_scales_path
), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
assert os.path.isfile(fp8_scales_path), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
for block in model.layers:
@ -81,9 +77,7 @@ def convert_to_fp8_quantized_model(
param = getattr(block.feed_forward, key)
param.weight = load_fp8(
param.weight,
fp8_scales[
f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"
],
fp8_scales[f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"],
fp8_activation_scale_ub,
)
else:
@ -172,9 +166,7 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
if prefix + "zeros" not in state_dict:
# Zero-point may not be saved in the state dict. In this case, we assume it's zero.
assert prefix + "scales" in state_dict
state_dict[prefix + "zeros"] = torch.zeros_like(
state_dict[prefix + "scales"]
)
state_dict[prefix + "zeros"] = torch.zeros_like(state_dict[prefix + "scales"])
def forward(self, input_: torch.Tensor) -> torch.Tensor:
module_out = super().forward(input_)
@ -229,9 +221,7 @@ class Int8WeightLinear(torch.nn.Linear):
bias: Whether to use bias.
"""
def __init__(
self, in_features: int, out_features: int, bias: bool = True, device=None
) -> None:
def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None) -> None:
super().__init__(in_features, out_features, bias, device=device)
self._register_load_state_dict_pre_hook(self.load_hook)
@ -295,9 +285,7 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
del module
setattr(model, module_name, quantized_module)
else:
_prepare_model_int4_weight_int8_dynamic_activation(
module, group_size, lora_rank, lora_scale
)
_prepare_model_int4_weight_int8_dynamic_activation(module, group_size, lora_rank, lora_scale)
return model
@ -321,9 +309,7 @@ def convert_to_int4_quantized_model(
group_size = model_args.quantization_args.group_size
if group_size is None:
raise ValueError(
"'group_size' cannot be None in 'quantization_args'. Please specify it."
)
raise ValueError("'group_size' cannot be None in 'quantization_args'. Please specify it.")
if model_args.lora_args is None:
# Certain quantized models (e.g., SpinQuant) may not have LoRA.
@ -333,8 +319,6 @@ def convert_to_int4_quantized_model(
lora_rank = model_args.lora_args.rank
lora_scale = model_args.lora_args.scale
_prepare_model_int4_weight_int8_dynamic_activation(
model, group_size, lora_rank, lora_scale
)
_prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
return model.to(device)

@ -76,9 +76,9 @@ def main(
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
assert model_parallel_size == len(checkpoints), (
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
)
ckpt_path = checkpoints[get_model_parallel_rank()]
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
@ -90,9 +90,9 @@ def main(
**params,
)
tokenizer = Tokenizer(model_path=tokenizer_path)
assert (
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
assert model_args.vocab_size == tokenizer.n_words, (
f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
)
# load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor)
@ -106,9 +106,7 @@ def main(
torch.set_default_tensor_type(torch.cuda.HalfTensor)
log.info(ckpt_path)
assert (
quantized_ckpt_dir is not None
), "QUantized checkpoint directory should not be None"
assert quantized_ckpt_dir is not None, "QUantized checkpoint directory should not be None"
fp8_scales = {}
for block in model.layers:
if isinstance(block, TransformerBlock):
@ -122,9 +120,7 @@ def main(
)
with torch.inference_mode():
block.feed_forward.w1.weight = Parameter(fp8_weight.weight)
fp8_scales[
f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
] = fp8_weight.scale
fp8_scales[f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"] = fp8_weight.scale
fp8_weight = quantize_fp8(
block.feed_forward.w3.weight,
@ -133,9 +129,7 @@ def main(
)
with torch.inference_mode():
block.feed_forward.w3.weight = Parameter(fp8_weight.weight)
fp8_scales[
f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
] = fp8_weight.scale
fp8_scales[f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"] = fp8_weight.scale
fp8_weight = quantize_fp8(
block.feed_forward.w2.weight,
@ -144,13 +138,9 @@ def main(
)
with torch.inference_mode():
block.feed_forward.w2.weight = Parameter(fp8_weight.weight)
fp8_scales[
f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
] = fp8_weight.scale
fp8_scales[f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"] = fp8_weight.scale
fp8_scales_path = os.path.join(
quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
)
fp8_scales_path = os.path.join(quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
torch.save(fp8_scales, fp8_scales_path)
ckpt_path = os.path.join(

@ -10,7 +10,6 @@ from pydantic import BaseModel
class SentenceTransformersInferenceConfig(BaseModel):
@classmethod
def sample_run_config(cls) -> Dict[str, Any]:
return {}

@ -53,7 +53,5 @@ class VLLMConfig(BaseModel):
repos = [m.huggingface_repo for m in permitted_models]
if model not in (descriptors + repos):
model_list = "\n\t".join(repos)
raise ValueError(
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
)
raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
return model

@ -176,13 +176,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
log.info("Sampling params: %s", sampling_params)
request_id = _random_uuid()
prompt = await chat_completion_request_to_prompt(
request, self.config.model, self.formatter
)
prompt = await chat_completion_request_to_prompt(request, self.config.model, self.formatter)
vllm_sampling_params = self._sampling_params(request.sampling_params)
results_generator = self.engine.generate(
prompt, vllm_sampling_params, request_id
)
results_generator = self.engine.generate(prompt, vllm_sampling_params, request_id)
if stream:
return self._stream_chat_completion(request, results_generator)
else:
@ -230,12 +226,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(
stream, self.formatter
):
async for chunk in process_chat_completion_stream_response(stream, self.formatter):
yield chunk
async def embeddings(
self, model_id: str, contents: List[InterleavedContent]
) -> EmbeddingsResponse:
async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
raise NotImplementedError()

@ -47,6 +47,4 @@ async def validate_input_dataset_schema(
if dataset_type not in EXPECTED_DATASET_SCHEMA:
raise ValueError(f"Dataset type {dataset_type} is not supported.")
validate_dataset_schema(
dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type]
)
validate_dataset_schema(dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type])

@ -42,9 +42,7 @@ class TorchtuneCheckpointer:
self._model_type = ModelType[model_type]
self._output_dir = output_dir
# get ckpt paths
self._checkpoint_path = Path.joinpath(
self._checkpoint_dir, self._checkpoint_file
)
self._checkpoint_path = Path.joinpath(self._checkpoint_dir, self._checkpoint_file)
def load_checkpoint(self) -> Dict[str, Any]:
"""
@ -57,13 +55,9 @@ class TorchtuneCheckpointer:
llama3_vision_meta_to_tune,
)
state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(
model_state_dict
)
state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(model_state_dict)
else:
state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(
model_state_dict
)
state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(model_state_dict)
# llama3_2 has tied weights, so we need to remove the output.weight key
if self._model_type == ModelType.LLAMA3_2:
@ -82,10 +76,7 @@ class TorchtuneCheckpointer:
epoch: int,
adapter_only: bool = False,
) -> str:
model_file_path = (
Path(self._output_dir)
/ f"{self._model_id}-{self._training_algorithm}-{epoch}"
)
model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"
model_file_path.mkdir(parents=True, exist_ok=True)
@ -116,22 +107,13 @@ class TorchtuneCheckpointer:
llama3_vision_tune_to_meta,
)
state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
model_state_dict
)
state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(model_state_dict)
else:
# llama3_2 has tied weights, so we need to add the output.weight key
if (
self._model_type == ModelType.LLAMA3_2
and "output.weight" not in model_state_dict
):
model_state_dict["output.weight"] = model_state_dict[
"tok_embeddings.weight"
]
if self._model_type == ModelType.LLAMA3_2 and "output.weight" not in model_state_dict:
model_state_dict["output.weight"] = model_state_dict["tok_embeddings.weight"]
state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
model_state_dict
)
state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(model_state_dict)
model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")

@ -15,18 +15,13 @@ from typing import Any, Mapping
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
def llama_stack_instruct_to_torchtune_instruct(
sample: Mapping[str, Any]
) -> Mapping[str, Any]:
assert (
ColumnName.chat_completion_input.value in sample
and ColumnName.expected_answer.value in sample
), "Invalid input row"
def llama_stack_instruct_to_torchtune_instruct(sample: Mapping[str, Any]) -> Mapping[str, Any]:
assert ColumnName.chat_completion_input.value in sample and ColumnName.expected_answer.value in sample, (
"Invalid input row"
)
input_messages = eval(str(sample[ColumnName.chat_completion_input.value]))
assert (
len(input_messages) == 1
), "llama stack intruct dataset format only supports 1 user message"
assert len(input_messages) == 1, "llama stack intruct dataset format only supports 1 user message"
input_message = input_messages[0]
assert "content" in input_message, "content not found in input message"
@ -48,13 +43,9 @@ def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str
roles = []
conversations = []
for message in dialog:
assert (
"role" in message and "content" in message
), "role and content must in message"
assert "role" in message and "content" in message, "role and content must in message"
roles.append(message["role"])
conversations.append(
{"from": role_map[message["role"]], "value": message["content"]}
)
conversations.append({"from": role_map[message["role"]], "value": message["content"]})
assert roles[0] == "user", "first message must be from user"
assert "assistant" in roles, "at least 1 message should be from assistant"

@ -61,8 +61,7 @@ class SFTDataset(Dataset):
if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
keys_str = ", ".join(tokenized_dict.keys())
error_message = (
"model_transform returned the following keys: "
f"{keys_str}. Must return 'tokens' and 'mask' as keys."
f"model_transform returned the following keys: {keys_str}. Must return 'tokens' and 'mask' as keys."
)
raise ValueError(error_message)

@ -119,9 +119,7 @@ class TorchtunePostTrainingImpl:
return ListPostTrainingJobsResponse(data=self.jobs_list)
@webmethod(route="/post-training/job/status")
async def get_training_job_status(
self, job_uuid: str
) -> Optional[PostTrainingJobStatusResponse]:
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
if job_uuid in self.jobs_status:
return self.jobs_status[job_uuid]
return None
@ -131,12 +129,8 @@ class TorchtunePostTrainingImpl:
raise NotImplementedError("Job cancel is not implemented yet")
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(
self, job_uuid: str
) -> Optional[PostTrainingJobArtifactsResponse]:
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
if job_uuid in self.checkpoints_dict:
checkpoints = self.checkpoints_dict.get(job_uuid, [])
return PostTrainingJobArtifactsResponse(
job_uuid=job_uuid, checkpoints=checkpoints
)
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=checkpoints)
return None

@ -94,9 +94,7 @@ class LoraFinetuningSingleDevice:
self.job_uuid = job_uuid
self.training_config = training_config
if not isinstance(algorithm_config, LoraFinetuningConfig):
raise ValueError(
"You need to speicifc LoraFinetuningConfig for LoRA finetuning"
)
raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning")
self.algorithm_config = algorithm_config
self._device = torchtune_utils.get_device(device="cuda")
self._dtype = training.get_dtype(training_config.dtype, device=self._device)
@ -105,10 +103,7 @@ class LoraFinetuningSingleDevice:
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))
paths = [
Path(checkpoint_dir / f"consolidated.{ext}")
for ext in ["pth", "00.pth"]
]
paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
if not any(p.exists() for p in paths):
checkpoint_dir = checkpoint_dir / "original"
@ -123,9 +118,7 @@ class LoraFinetuningSingleDevice:
else:
model = resolve_model(self.model_id)
if model is None:
raise ValueError(
f"{self.model_id} not found. Your model id should be in the llama models SKU list"
)
raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
self.checkpoint_dir = model_checkpoint_dir(model)
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
@ -196,9 +189,7 @@ class LoraFinetuningSingleDevice:
self._tokenizer = await self._setup_tokenizer()
log.info("Tokenizer is initialized.")
self._optimizer = await self._setup_optimizer(
optimizer_config=self.training_config.optimizer_config
)
self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config)
log.info("Optimizer is initialized.")
self._loss_fn = CEWithChunkedOutputLoss()
@ -226,13 +217,8 @@ class LoraFinetuningSingleDevice:
# by the dataloader and the max_steps_per_epoch param set by the user and is used
# for logging and tracking training state. This should be computed after the dataloader
# has been setup
self._steps_per_epoch = (
len(self._training_dataloader) // self._gradient_accumulation_steps
)
if (
self.max_steps_per_epoch is not None
and self.max_steps_per_epoch < self._steps_per_epoch
):
self._steps_per_epoch = len(self._training_dataloader) // self._gradient_accumulation_steps
if self.max_steps_per_epoch is not None and self.max_steps_per_epoch < self._steps_per_epoch:
self._steps_per_epoch = self.max_steps_per_epoch
self.global_step = self.epochs_run * self._steps_per_epoch
@ -246,9 +232,7 @@ class LoraFinetuningSingleDevice:
log.info("Learning rate scheduler is initialized.")
# Used to ignore labels for loss computation
self.ignore_labels_cache = torch.full(
(self._batch_size, 1), self._loss_fn.ignore_index, device=self._device
)
self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device)
async def _setup_model(
self,
@ -282,13 +266,9 @@ class LoraFinetuningSingleDevice:
set_trainable_params(model, self.adapter_params)
if enable_activation_checkpointing:
training.set_activation_checkpointing(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)
training.set_activation_checkpointing(model, auto_wrap_policy={modules.TransformerSelfAttentionLayer})
base_missing, base_unexpected = model.load_state_dict(
base_model_state_dict, strict=False
)
base_missing, base_unexpected = model.load_state_dict(base_model_state_dict, strict=False)
# This is for any adapters that need to be initialized after base weights
# have been loaded (e.g. DoRA).
@ -297,9 +277,7 @@ class LoraFinetuningSingleDevice:
if hasattr(m, "initialize_dora_magnitude"):
m.initialize_dora_magnitude()
if lora_weights_state_dict:
lora_missing, lora_unexpected = model.load_state_dict(
lora_weights_state_dict, strict=False
)
lora_missing, lora_unexpected = model.load_state_dict(lora_weights_state_dict, strict=False)
else:
lora_missing, lora_unexpected = None, None
validate_missing_and_unexpected_for_lora(
@ -313,14 +291,10 @@ class LoraFinetuningSingleDevice:
)
# Validate model adapter params were loaded in with the expected dtype
training.validate_expected_param_dtype(
self.adapter_params.items(), dtype=self._dtype
)
training.validate_expected_param_dtype(self.adapter_params.items(), dtype=self._dtype)
# activation offloading
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
model, enable_activation_offloading
)
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading)
memory_stats = training.get_memory_stats(device=self._device)
training.log_memory_stats(memory_stats)
@ -456,9 +430,7 @@ class LoraFinetuningSingleDevice:
# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
# But this way we dont need to slice the logits. We just add an ignore index to labels.
labels = torch.hstack(
(labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
)
labels = torch.hstack((labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]))
if not isinstance(logits, list):
labels = labels.reshape(-1)
logits = logits.reshape(-1, logits.size(-1))
@ -487,9 +459,7 @@ class LoraFinetuningSingleDevice:
for curr_epoch in range(self.epochs_run, self.total_epochs):
# Update the sampler to ensure data is correctly shuffled across epochs
# in case shuffle is True
metric_logger = DiskLogger(
log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
)
metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}")
self._training_sampler.set_epoch(curr_epoch)
loss_to_log = 0.0
@ -497,8 +467,7 @@ class LoraFinetuningSingleDevice:
for idx, batch in enumerate(self._training_dataloader):
if (
self.max_steps_per_epoch is not None
and (idx // self._gradient_accumulation_steps)
== self.max_steps_per_epoch
and (idx // self._gradient_accumulation_steps) == self.max_steps_per_epoch
):
break
@ -506,9 +475,7 @@ class LoraFinetuningSingleDevice:
# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
current_num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum()
num_tokens += current_num_tokens
# Loss is normalized by default so we multiply by the number of tokens
@ -533,9 +500,7 @@ class LoraFinetuningSingleDevice:
loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
)
pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")
time_per_step = time.perf_counter() - t0
log_dict = {

@ -67,10 +67,6 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
violation = SafetyViolation(
violation_level=(ViolationLevel.ERROR),
user_message="Sorry, I found security concerns in the code.",
metadata={
"violation_type": ",".join(
[issue.pattern_id for issue in result.issues_found]
)
},
metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])},
)
return RunShieldResponse(violation=violation)

@ -10,9 +10,7 @@ from .config import LlamaGuardConfig
async def get_provider_impl(config: LlamaGuardConfig, deps):
from .llama_guard import LlamaGuardSafetyImpl
assert isinstance(
config, LlamaGuardConfig
), f"Unexpected config type: {type(config)}"
assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"
impl = LlamaGuardSafetyImpl(config, deps)
await impl.initialize()

@ -102,8 +102,7 @@ LLAMA_GUARD_MODEL_IDS = {
}
MODEL_TO_SAFETY_CATEGORIES_MAP = {
"meta-llama/Llama-Guard-3-8B": DEFAULT_LG_V3_SAFETY_CATEGORIES
+ [CAT_CODE_INTERPRETER_ABUSE],
"meta-llama/Llama-Guard-3-8B": DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE],
"meta-llama/Llama-Guard-3-1B": DEFAULT_LG_V3_SAFETY_CATEGORIES,
"meta-llama/Llama-Guard-3-11B-Vision": DEFAULT_LG_V3_SAFETY_CATEGORIES,
}
@ -133,9 +132,7 @@ Provide your safety assessment for ONLY THE LAST $agent_type message in the abov
- If unsafe, a second line must include a comma-separated list of violated categories."""
PROMPT_TEMPLATE = Template(
f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}"
)
PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}")
class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
@ -233,9 +230,7 @@ class LlamaGuardShield:
if messages[0].role != Role.user.value:
raise ValueError("Messages must start with user")
if len(messages) >= 2 and (
messages[0].role == Role.user.value and messages[1].role == Role.user.value
):
if len(messages) >= 2 and (messages[0].role == Role.user.value and messages[1].role == Role.user.value):
messages = messages[1:]
for i in range(1, len(messages)):
@ -263,10 +258,7 @@ class LlamaGuardShield:
stream=True,
):
event = chunk.event
if (
event.event_type == ChatCompletionResponseEventType.progress
and event.delta.type == "text"
):
if event.event_type == ChatCompletionResponseEventType.progress and event.delta.type == "text":
content += event.delta.text
content = content.strip()
@ -313,10 +305,7 @@ class LlamaGuardShield:
categories = self.get_safety_categories()
categories_str = "\n".join(categories)
conversations_str = "\n\n".join(
[
f"{m.role.capitalize()}: {interleaved_content_as_str(m.content)}"
for m in messages
]
[f"{m.role.capitalize()}: {interleaved_content_as_str(m.content)}" for m in messages]
)
return PROMPT_TEMPLATE.substitute(
agent_type=messages[-1].role.capitalize(),

@ -46,9 +46,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
async def register_shield(self, shield: Shield) -> None:
if shield.provider_resource_id != PROMPT_GUARD_MODEL:
raise ValueError(
f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. "
)
raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ")
async def run_shield(
self,
@ -71,9 +69,7 @@ class PromptGuardShield:
threshold: float = 0.9,
temperature: float = 1.0,
):
assert (
model_dir is not None
), "Must provide a model directory for prompt injection shield"
assert model_dir is not None, "Must provide a model directory for prompt injection shield"
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
@ -85,9 +81,7 @@ class PromptGuardShield:
# load model and tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForSequenceClassification.from_pretrained(
model_dir, device_map=self.device
)
self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)
async def run(self, messages: List[Message]) -> RunShieldResponse:
message = messages[-1]
@ -117,10 +111,7 @@ class PromptGuardShield:
"violation_type": f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
},
)
elif (
self.config.guard_type == PromptGuardType.jailbreak.value
and score_malicious > self.threshold
):
elif self.config.guard_type == PromptGuardType.jailbreak.value and score_malicious > self.threshold:
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR,
violation_type=f"prompt_injection:malicious={score_malicious}",

@ -54,15 +54,11 @@ class BasicScoringImpl(
async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = [
fn_def
for impl in self.scoring_fn_id_impls.values()
for fn_def in impl.get_supported_scoring_fn_defs()
fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
]
for f in scoring_fn_defs_list:
assert f.identifier.startswith(
"basic"
), "All basic scoring fn must have identifier prefixed with 'basic'! "
assert f.identifier.startswith("basic"), "All basic scoring fn must have identifier prefixed with 'basic'! "
return scoring_fn_defs_list
@ -76,9 +72,7 @@ class BasicScoringImpl(
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
@ -108,12 +102,8 @@ class BasicScoringImpl(
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
score_results = await scoring_fn.score(
input_rows, scoring_fn_id, scoring_fn_params
)
agg_results = await scoring_fn.aggregate(
score_results, scoring_fn_id, scoring_fn_params
)
score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,

@ -32,9 +32,7 @@ class EqualityScoringFn(RegisteredBaseScoringFn):
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
assert "expected_answer" in input_row, "Expected answer not found in input row."
assert (
"generated_answer" in input_row
), "Generated answer not found in input row."
assert "generated_answer" in input_row, "Generated answer not found in input row."
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]

@ -18,7 +18,5 @@ equality = ScoringFn(
provider_id="basic",
provider_resource_id="equality",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.accuracy]
),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
)

@ -55,9 +55,7 @@ MULTILINGUAL_ANSWER_REGEXES = [
r"Àṣàyàn\s*:",
]
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[]|[]|[]|[])"
)
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[]|[]|[]|[])"
regex_parser_multiple_choice_answer = ScoringFn(
identifier="basic::regex_parser_multiple_choice_answer",
@ -66,10 +64,7 @@ regex_parser_multiple_choice_answer = ScoringFn(
provider_id="basic",
provider_resource_id="regex-parser-multiple-choice-answer",
params=RegexParserScoringFnParams(
parsing_regexes=[
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
for x in MULTILINGUAL_ANSWER_REGEXES
],
parsing_regexes=[MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x) for x in MULTILINGUAL_ANSWER_REGEXES],
aggregation_functions=[AggregationFunctionType.accuracy],
),
)

@ -18,7 +18,5 @@ subset_of = ScoringFn(
return_type=NumberType(),
provider_id="basic",
provider_resource_id="subset-of",
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.accuracy]
),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
)

@ -33,17 +33,14 @@ class RegexParserScoringFn(RegisteredBaseScoringFn):
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
assert (
scoring_fn_identifier is not None
), "Scoring function identifier not found."
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
if scoring_params is not None:
fn_def.params = scoring_params
assert (
fn_def.params is not None
and fn_def.params.type == ScoringFnParamsType.regex_parser.value
), f"RegexParserScoringFnParams not found for {fn_def}."
assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
f"RegexParserScoringFnParams not found for {fn_def}."
)
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]

@ -124,12 +124,10 @@ class BraintrustScoringImpl(
self.datasets_api = datasets_api
self.braintrust_evaluators = {
entry.identifier: entry.evaluator
for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
entry.identifier: entry.evaluator for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
}
self.supported_fn_defs_registry = {
entry.identifier: entry.fn_def
for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
entry.identifier: entry.fn_def for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
}
async def initialize(self) -> None: ...
@ -139,16 +137,14 @@ class BraintrustScoringImpl(
async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
for f in scoring_fn_defs_list:
assert f.identifier.startswith(
"braintrust"
), "All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
assert f.identifier.startswith("braintrust"), (
"All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
)
return scoring_fn_defs_list
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None:
raise NotImplementedError(
"Registering scoring function not allowed for braintrust provider"
)
raise NotImplementedError("Registering scoring function not allowed for braintrust provider")
async def set_api_key(self) -> None:
# api key is in the request headers
@ -171,17 +167,13 @@ class BraintrustScoringImpl(
await self.set_api_key()
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,
)
res = await self.score(
input_rows=all_rows.rows, scoring_functions=scoring_functions
)
res = await self.score(input_rows=all_rows.rows, scoring_functions=scoring_functions)
if save_results_dataset:
# TODO: persist and register dataset on to server for reading
# self.datasets_api.register_dataset()
@ -222,13 +214,8 @@ class BraintrustScoringImpl(
if scoring_fn_id not in self.supported_fn_defs_registry:
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
score_results = [
await self.score_row(input_row, scoring_fn_id)
for input_row in input_rows
]
aggregation_functions = self.supported_fn_defs_registry[
scoring_fn_id
].params.aggregation_functions
score_results = [await self.score_row(input_row, scoring_fn_id) for input_row in input_rows]
aggregation_functions = self.supported_fn_defs_registry[scoring_fn_id].params.aggregation_functions
# override scoring_fn params if provided
if scoring_functions[scoring_fn_id] is not None:

@ -21,7 +21,5 @@ answer_correctness_fn_def = ScoringFn(
provider_id="braintrust",
provider_resource_id="answer-correctness",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
)

@ -20,7 +20,5 @@ answer_relevancy_fn_def = ScoringFn(
provider_id="braintrust",
provider_resource_id="answer-relevancy",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
)

@ -20,7 +20,5 @@ answer_similarity_fn_def = ScoringFn(
provider_id="braintrust",
provider_resource_id="answer-similarity",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
)

Some files were not shown because too many files have changed in this diff.