forked from phoenix-oss/llama-stack-mirror
Fix precommit check after moving to ruff (#927)

The lint check on the main branch is failing. This fixes the lint check after we moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We need to move to a `ruff.toml` file as well as fix, and for now ignore, some additional checks.

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
Parent: 4773092dd1
Commit: 34ab7a3b6c
217 changed files with 981 additions and 2681 deletions
@@ -1,7 +1,8 @@
-[flake8]
 # Suggested config from pytorch that we can adapt
-select = B,C,E,F,N,P,T4,W,B9,TOR0,TOR1,TOR2
-max-line-length = 120
+lint.select = ["B", "C", "E" , "F" , "N", "W", "B9"]
+line-length = 120
+
 # C408 ignored because we like the dict keyword argument syntax
 # E501 is not flexible enough, we're using B950 instead
 # N812 ignored because import torch.nn.functional as F is PyTorch convention
@@ -9,23 +10,28 @@ max-line-length = 120
 # E731 allow usage of assigning lambda expressions
 # E701 let black auto-format statements on one line
 # E704 let black auto-format statements on one line
-ignore =
-    E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731,E701,E704
+lint.ignore = [
+    "E203", "E305", "E402", "E501", "E721", "E741", "F405", "F821", "F841",
+    "C408", "E302", "W291", "E303", "N812", "N817", "E731", "E701",
+    # These are the additional ones we started ignoring after moving to ruff. We should look into each one of them later.
+    "C901", "C405", "C414", "N803", "N999", "C403", "C416", "B028", "C419", "C401", "B023",
     # shebang has extra meaning in fbcode lints, so I think it's not worth trying
     # to line this up with executable bit
-    EXE001,
+    "EXE001",
     # random naming hints don't need
-    N802,
+    "N802",
     # these ignores are from flake8-bugbear; please fix!
-    B007,B008,B950
-optional-ascii-coding = True
-exclude =
-    ./.git,
-    ./docs/*,
-    ./build,
-    ./scripts,
-    ./venv,
-    *.pyi,
-    .pre-commit-config.yaml,
-    *.md,
-    .flake8
+    "B007", "B008"
+]
+
+exclude = [
+    "./.git",
+    "./docs/*",
+    "./build",
+    "./scripts",
+    "./venv",
+    "*.pyi",
+    ".pre-commit-config.yaml",
+    "*.md",
+    ".flake8"
+]
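To make one of the newly ignored rules concrete: C408 is the flake8-comprehensions check that flags `dict()` calls which could be written as literals, and the config above keeps it ignored because the project prefers the keyword-argument style. A small illustrative snippet, not taken from this commit (variable names are invented for the example):

```python
# Illustrative only -- not part of the diff.
# Ruff's C408 would flag the dict() call and suggest the literal form below;
# this repository ignores C408 because it prefers the keyword-argument style.
settings = dict(model="my-model", max_tokens=2048)  # flagged by C408
settings_literal = {"model": "my-model", "max_tokens": 2048}  # form C408 suggests
```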
@@ -77,7 +77,7 @@ agent_config = AgentConfig(
     instructions="You are a helpful assistant",
     # Enable both RAG and tool usage
     toolgroups=[
-        {"name": "builtin::rag", "args": {"vector_db_ids": ["my_docs"]}}.
+        {"name": "builtin::rag", "args": {"vector_db_ids": ["my_docs"]}},
         "builtin::code_interpreter",
     ],
     # Configure safety
@@ -86,13 +86,9 @@ agent_config = AgentConfig(
     # Control the inference loop
     max_infer_iters=5,
     sampling_params={
-        "strategy": {
-            "type": "top_p",
-            "temperature": 0.7,
-            "top_p": 0.95
-        },
-        "max_tokens": 2048
-    }
+        "strategy": {"type": "top_p", "temperature": 0.7, "top_p": 0.95},
+        "max_tokens": 2048,
+    },
 )
 
 agent = Agent(client, agent_config)
@@ -101,11 +97,13 @@ session_id = agent.create_session("monitored_session")
 # Stream the agent's execution steps
 response = agent.create_turn(
     messages=[{"role": "user", "content": "Analyze this code and run it"}],
-    attachments=[{
-        "content": "https://raw.githubusercontent.com/example/code.py",
-        "mime_type": "text/plain"
-    }],
-    session_id=session_id
+    attachments=[
+        {
+            "content": "https://raw.githubusercontent.com/example/code.py",
+            "mime_type": "text/plain",
+        }
+    ],
+    session_id=session_id,
 )
 
 # Monitor each step of execution
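The hunk above reformats an example that streams an agent's execution steps. For context, a minimal sketch of how such a streamed `create_turn` response is typically consumed with the client-side `EventLogger` helper that this commit touches elsewhere; treat the exact iteration pattern as an assumption rather than something shown in this diff:

```python
# Sketch, assuming the llama_stack_client EventLogger behaves as in the upstream examples.
from llama_stack_client.lib.agents.event_logger import EventLogger

for log in EventLogger().log(response):
    log.print()  # prints each inference / tool-execution / shield step as it arrives
```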
@ -15,6 +15,7 @@ This first example walks you through how to evaluate a model candidate served by
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import datasets
|
import datasets
|
||||||
|
|
||||||
ds = datasets.load_dataset(path="llamastack/mmmu", name="Agriculture", split="dev")
|
ds = datasets.load_dataset(path="llamastack/mmmu", name="Agriculture", split="dev")
|
||||||
ds = ds.select_columns(["chat_completion_input", "input_query", "expected_answer"])
|
ds = ds.select_columns(["chat_completion_input", "input_query", "expected_answer"])
|
||||||
eval_rows = ds.to_pandas().to_dict(orient="records")
|
eval_rows = ds.to_pandas().to_dict(orient="records")
|
||||||
|
@ -43,7 +44,7 @@ system_message = {
|
||||||
client.eval_tasks.register(
|
client.eval_tasks.register(
|
||||||
eval_task_id="meta-reference::mmmu",
|
eval_task_id="meta-reference::mmmu",
|
||||||
dataset_id=f"mmmu-{subset}-{split}",
|
dataset_id=f"mmmu-{subset}-{split}",
|
||||||
scoring_functions=["basic::regex_parser_multiple_choice_answer"]
|
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
|
||||||
)
|
)
|
||||||
|
|
||||||
response = client.eval.evaluate_rows(
|
response = client.eval.evaluate_rows(
|
||||||
|
@ -62,9 +63,9 @@ response = client.eval.evaluate_rows(
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
"repeat_penalty": 1.0,
|
"repeat_penalty": 1.0,
|
||||||
},
|
},
|
||||||
"system_message": system_message
|
"system_message": system_message,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -88,7 +89,7 @@ _ = client.datasets.register(
|
||||||
"input_query": {"type": "string"},
|
"input_query": {"type": "string"},
|
||||||
"expected_answer": {"type": "string"},
|
"expected_answer": {"type": "string"},
|
||||||
"chat_completion_input": {"type": "chat_completion_input"},
|
"chat_completion_input": {"type": "chat_completion_input"},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
eval_rows = client.datasetio.get_rows_paginated(
|
eval_rows = client.datasetio.get_rows_paginated(
|
||||||
|
@ -101,7 +102,7 @@ eval_rows = client.datasetio.get_rows_paginated(
|
||||||
client.eval_tasks.register(
|
client.eval_tasks.register(
|
||||||
eval_task_id="meta-reference::simpleqa",
|
eval_task_id="meta-reference::simpleqa",
|
||||||
dataset_id=simpleqa_dataset_id,
|
dataset_id=simpleqa_dataset_id,
|
||||||
scoring_functions=["llm-as-judge::405b-simpleqa"]
|
scoring_functions=["llm-as-judge::405b-simpleqa"],
|
||||||
)
|
)
|
||||||
|
|
||||||
response = client.eval.evaluate_rows(
|
response = client.eval.evaluate_rows(
|
||||||
|
@ -120,8 +121,8 @@ response = client.eval.evaluate_rows(
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
"repeat_penalty": 1.0,
|
"repeat_penalty": 1.0,
|
||||||
},
|
},
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -144,14 +145,14 @@ agent_config = {
|
||||||
{
|
{
|
||||||
"type": "brave_search",
|
"type": "brave_search",
|
||||||
"engine": "tavily",
|
"engine": "tavily",
|
||||||
"api_key": userdata.get("TAVILY_SEARCH_API_KEY")
|
"api_key": userdata.get("TAVILY_SEARCH_API_KEY"),
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"tool_choice": "auto",
|
"tool_choice": "auto",
|
||||||
"tool_prompt_format": "json",
|
"tool_prompt_format": "json",
|
||||||
"input_shields": [],
|
"input_shields": [],
|
||||||
"output_shields": [],
|
"output_shields": [],
|
||||||
"enable_session_persistence": False
|
"enable_session_persistence": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.eval.evaluate_rows(
|
response = client.eval.evaluate_rows(
|
||||||
|
@ -163,7 +164,7 @@ response = client.eval.evaluate_rows(
|
||||||
"eval_candidate": {
|
"eval_candidate": {
|
||||||
"type": "agent",
|
"type": "agent",
|
||||||
"config": agent_config,
|
"config": agent_config,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
|
@ -13,7 +13,7 @@ Here's how to set up basic evaluation:
|
||||||
response = client.eval_tasks.register(
|
response = client.eval_tasks.register(
|
||||||
eval_task_id="my_eval",
|
eval_task_id="my_eval",
|
||||||
dataset_id="my_dataset",
|
dataset_id="my_dataset",
|
||||||
scoring_functions=["accuracy", "relevance"]
|
scoring_functions=["accuracy", "relevance"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run evaluation
|
# Run evaluation
|
||||||
|
@ -21,16 +21,10 @@ job = client.eval.run_eval(
|
||||||
task_id="my_eval",
|
task_id="my_eval",
|
||||||
task_config={
|
task_config={
|
||||||
"type": "app",
|
"type": "app",
|
||||||
"eval_candidate": {
|
"eval_candidate": {"type": "agent", "config": agent_config},
|
||||||
"type": "agent",
|
},
|
||||||
"config": agent_config
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get results
|
# Get results
|
||||||
result = client.eval.job_result(
|
result = client.eval.job_result(task_id="my_eval", job_id=job.job_id)
|
||||||
task_id="my_eval",
|
|
||||||
job_id=job.job_id
|
|
||||||
)
|
|
||||||
```
|
```
|
||||||
|
|
|
@ -34,15 +34,16 @@ chunks = [
|
||||||
{
|
{
|
||||||
"document_id": "doc1",
|
"document_id": "doc1",
|
||||||
"content": "Your document text here",
|
"content": "Your document text here",
|
||||||
"mime_type": "text/plain"
|
"mime_type": "text/plain",
|
||||||
},
|
},
|
||||||
...
|
...,
|
||||||
]
|
]
|
||||||
client.vector_io.insert(vector_db_id, chunks)
|
client.vector_io.insert(vector_db_id, chunks)
|
||||||
|
|
||||||
# You can then query for these chunks
|
# You can then query for these chunks
|
||||||
chunks_response = client.vector_io.query(vector_db_id, query="What do you know about...")
|
chunks_response = client.vector_io.query(
|
||||||
|
vector_db_id, query="What do you know about..."
|
||||||
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Using the RAG Tool
|
### Using the RAG Tool
|
||||||
|
@ -81,7 +82,6 @@ results = client.tool_runtime.rag_tool.query(
|
||||||
One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:
|
One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|
||||||
# Configure agent with memory
|
# Configure agent with memory
|
||||||
agent_config = AgentConfig(
|
agent_config = AgentConfig(
|
||||||
model="Llama3.2-3B-Instruct",
|
model="Llama3.2-3B-Instruct",
|
||||||
|
@ -91,9 +91,9 @@ agent_config = AgentConfig(
|
||||||
"name": "builtin::rag",
|
"name": "builtin::rag",
|
||||||
"args": {
|
"args": {
|
||||||
"vector_db_ids": [vector_db_id],
|
"vector_db_ids": [vector_db_id],
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
],
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = Agent(client, agent_config)
|
agent = Agent(client, agent_config)
|
||||||
|
@ -101,25 +101,21 @@ session_id = agent.create_session("rag_session")
|
||||||
|
|
||||||
# Initial document ingestion
|
# Initial document ingestion
|
||||||
response = agent.create_turn(
|
response = agent.create_turn(
|
||||||
messages=[{
|
messages=[
|
||||||
"role": "user",
|
{"role": "user", "content": "I am providing some documents for reference."}
|
||||||
"content": "I am providing some documents for reference."
|
],
|
||||||
}],
|
|
||||||
documents=[
|
documents=[
|
||||||
dict(
|
dict(
|
||||||
content="https://raw.githubusercontent.com/example/doc.rst",
|
content="https://raw.githubusercontent.com/example/doc.rst",
|
||||||
mime_type="text/plain"
|
mime_type="text/plain",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
session_id=session_id
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Query with RAG
|
# Query with RAG
|
||||||
response = agent.create_turn(
|
response = agent.create_turn(
|
||||||
messages=[{
|
messages=[{"role": "user", "content": "What are the key topics in the documents?"}],
|
||||||
"role": "user",
|
session_id=session_id,
|
||||||
"content": "What are the key topics in the documents?"
|
|
||||||
}],
|
|
||||||
session_id=session_id
|
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
|
@@ -5,15 +5,11 @@ Safety is a critical component of any AI application. Llama Stack provides a Shi
 ```python
 # Register a safety shield
 shield_id = "content_safety"
-client.shields.register(
-    shield_id=shield_id,
-    provider_shield_id="llama-guard-basic"
-)
+client.shields.register(shield_id=shield_id, provider_shield_id="llama-guard-basic")
 
 # Run content through shield
 response = client.safety.run_shield(
-    shield_id=shield_id,
-    messages=[{"role": "user", "content": "User message here"}]
+    shield_id=shield_id, messages=[{"role": "user", "content": "User message here"}]
 )
 
 if response.violation:
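The hunk above is cut off at the `if response.violation:` check. A hedged sketch of how that branch is commonly handled; the `user_message` field is an assumption, not something shown in this diff:

```python
# Hypothetical continuation, not part of this commit.
if response.violation:
    violation = response.violation
    # Surface the shield's explanation rather than the original model output.
    print(f"Blocked by shield '{shield_id}': {violation.user_message}")
```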
@ -8,24 +8,16 @@ The telemetry system supports three main types of events:
|
||||||
- **Unstructured Log Events**: Free-form log messages with severity levels
|
- **Unstructured Log Events**: Free-form log messages with severity levels
|
||||||
```python
|
```python
|
||||||
unstructured_log_event = UnstructuredLogEvent(
|
unstructured_log_event = UnstructuredLogEvent(
|
||||||
message="This is a log message",
|
message="This is a log message", severity=LogSeverity.INFO
|
||||||
severity=LogSeverity.INFO
|
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
- **Metric Events**: Numerical measurements with units
|
- **Metric Events**: Numerical measurements with units
|
||||||
```python
|
```python
|
||||||
metric_event = MetricEvent(
|
metric_event = MetricEvent(metric="my_metric", value=10, unit="count")
|
||||||
metric="my_metric",
|
|
||||||
value=10,
|
|
||||||
unit="count"
|
|
||||||
)
|
|
||||||
```
|
```
|
||||||
- **Structured Log Events**: System events like span start/end. Extensible to add more structured log types.
|
- **Structured Log Events**: System events like span start/end. Extensible to add more structured log types.
|
||||||
```python
|
```python
|
||||||
structured_log_event = SpanStartPayload(
|
structured_log_event = SpanStartPayload(name="my_span", parent_span_id="parent_span_id")
|
||||||
name="my_span",
|
|
||||||
parent_span_id="parent_span_id"
|
|
||||||
)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Spans and Traces
|
### Spans and Traces
|
||||||
|
|
|
@ -35,7 +35,7 @@ Example client SDK call to register a "websearch" toolgroup that is provided by
|
||||||
client.toolgroups.register(
|
client.toolgroups.register(
|
||||||
toolgroup_id="builtin::websearch",
|
toolgroup_id="builtin::websearch",
|
||||||
provider_id="brave-search",
|
provider_id="brave-search",
|
||||||
args={"max_results": 5}
|
args={"max_results": 5},
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -50,8 +50,7 @@ The Code Interpreter allows execution of Python code within a controlled environ
|
||||||
```python
|
```python
|
||||||
# Register Code Interpreter tool group
|
# Register Code Interpreter tool group
|
||||||
client.toolgroups.register(
|
client.toolgroups.register(
|
||||||
toolgroup_id="builtin::code_interpreter",
|
toolgroup_id="builtin::code_interpreter", provider_id="code_interpreter"
|
||||||
provider_id="code_interpreter"
|
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -68,16 +67,14 @@ The WolframAlpha tool provides access to computational knowledge through the Wol
|
||||||
```python
|
```python
|
||||||
# Register WolframAlpha tool group
|
# Register WolframAlpha tool group
|
||||||
client.toolgroups.register(
|
client.toolgroups.register(
|
||||||
toolgroup_id="builtin::wolfram_alpha",
|
toolgroup_id="builtin::wolfram_alpha", provider_id="wolfram-alpha"
|
||||||
provider_id="wolfram-alpha"
|
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
```python
|
```python
|
||||||
result = client.tool_runtime.invoke_tool(
|
result = client.tool_runtime.invoke_tool(
|
||||||
tool_name="wolfram_alpha",
|
tool_name="wolfram_alpha", args={"query": "solve x^2 + 2x + 1 = 0"}
|
||||||
args={"query": "solve x^2 + 2x + 1 = 0"}
|
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -90,10 +87,7 @@ The Memory tool enables retrieval of context from various types of memory banks
|
||||||
client.toolgroups.register(
|
client.toolgroups.register(
|
||||||
toolgroup_id="builtin::memory",
|
toolgroup_id="builtin::memory",
|
||||||
provider_id="memory",
|
provider_id="memory",
|
||||||
args={
|
args={"max_chunks": 5, "max_tokens_in_context": 4096},
|
||||||
"max_chunks": 5,
|
|
||||||
"max_tokens_in_context": 4096
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -136,9 +130,7 @@ config = AgentConfig(
|
||||||
toolgroups=[
|
toolgroups=[
|
||||||
"builtin::websearch",
|
"builtin::websearch",
|
||||||
],
|
],
|
||||||
client_tools=[
|
client_tools=[ToolDef(name="client_tool", description="Client provided tool")],
|
||||||
ToolDef(name="client_tool", description="Client provided tool")
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -167,9 +159,9 @@ Example tool definition:
|
||||||
"name": "query",
|
"name": "query",
|
||||||
"parameter_type": "string",
|
"parameter_type": "string",
|
||||||
"description": "The query to search for",
|
"description": "The query to search for",
|
||||||
"required": True
|
"required": True,
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -179,8 +171,7 @@ Tools can be invoked using the `invoke_tool` method:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
result = client.tool_runtime.invoke_tool(
|
result = client.tool_runtime.invoke_tool(
|
||||||
tool_name="web_search",
|
tool_name="web_search", kwargs={"query": "What is the capital of France?"}
|
||||||
kwargs={"query": "What is the capital of France?"}
|
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@@ -1,9 +1,9 @@
 # Using Llama Stack as a Library
 
 If you are planning to use an external service for Inference (even Ollama or TGI counts as external), it is often easier to use Llama Stack as a library. This avoids the overhead of setting up a server.
-```python
+```bash
 # setup
-pip install llama-stack
+uv pip install llama-stack
 llama stack build --template together --image-type venv
 ```
@@ -13,7 +13,7 @@ from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
 client = LlamaStackAsLibraryClient(
     "ollama",
     # provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
-    provider_data = {"tavily_search_api_key": os.environ['TAVILY_SEARCH_API_KEY']}
+    provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
 )
 await client.initialize()
 ```
@ -96,18 +96,26 @@ Here is a simple example to perform chat completions using the SDK.
|
||||||
```python
|
```python
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
def create_http_client():
|
def create_http_client():
|
||||||
from llama_stack_client import LlamaStackClient
|
from llama_stack_client import LlamaStackClient
|
||||||
return LlamaStackClient(base_url=f"http://localhost:{os.environ['LLAMA_STACK_PORT']}")
|
|
||||||
|
return LlamaStackClient(
|
||||||
|
base_url=f"http://localhost:{os.environ['LLAMA_STACK_PORT']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_library_client(template="ollama"):
|
def create_library_client(template="ollama"):
|
||||||
from llama_stack import LlamaStackAsLibraryClient
|
from llama_stack import LlamaStackAsLibraryClient
|
||||||
|
|
||||||
client = LlamaStackAsLibraryClient(template)
|
client = LlamaStackAsLibraryClient(template)
|
||||||
client.initialize()
|
client.initialize()
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
client = create_library_client() # or create_http_client() depending on the environment you picked
|
client = (
|
||||||
|
create_library_client()
|
||||||
|
) # or create_http_client() depending on the environment you picked
|
||||||
|
|
||||||
# List available models
|
# List available models
|
||||||
models = client.models.list()
|
models = client.models.list()
|
||||||
|
@ -120,8 +128,8 @@ response = client.inference.chat_completion(
|
||||||
model_id=os.environ["INFERENCE_MODEL"],
|
model_id=os.environ["INFERENCE_MODEL"],
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
{"role": "user", "content": "Write a haiku about coding"}
|
{"role": "user", "content": "Write a haiku about coding"},
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
print(response.completion_message.content)
|
print(response.completion_message.content)
|
||||||
```
|
```
|
||||||
|
@ -139,7 +147,9 @@ from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||||
from llama_stack_client.types.agent_create_params import AgentConfig
|
from llama_stack_client.types.agent_create_params import AgentConfig
|
||||||
from llama_stack_client.types import Document
|
from llama_stack_client.types import Document
|
||||||
|
|
||||||
client = create_library_client() # or create_http_client() depending on the environment you picked
|
client = (
|
||||||
|
create_library_client()
|
||||||
|
) # or create_http_client() depending on the environment you picked
|
||||||
|
|
||||||
# Documents to be used for RAG
|
# Documents to be used for RAG
|
||||||
urls = ["chat.rst", "llama3.rst", "datasets.rst", "lora_finetune.rst"]
|
urls = ["chat.rst", "llama3.rst", "datasets.rst", "lora_finetune.rst"]
|
||||||
|
@ -174,12 +184,12 @@ agent_config = AgentConfig(
|
||||||
instructions="You are a helpful assistant",
|
instructions="You are a helpful assistant",
|
||||||
enable_session_persistence=False,
|
enable_session_persistence=False,
|
||||||
# Define tools available to the agent
|
# Define tools available to the agent
|
||||||
toolgroups = [
|
toolgroups=[
|
||||||
{
|
{
|
||||||
"name": "builtin::rag",
|
"name": "builtin::rag",
|
||||||
"args" : {
|
"args": {
|
||||||
"vector_db_ids": [vector_db_id],
|
"vector_db_ids": [vector_db_id],
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -193,7 +203,7 @@ user_prompts = [
|
||||||
|
|
||||||
# Run the agent loop by calling the `create_turn` method
|
# Run the agent loop by calling the `create_turn` method
|
||||||
for prompt in user_prompts:
|
for prompt in user_prompts:
|
||||||
cprint(f'User> {prompt}', 'green')
|
cprint(f"User> {prompt}", "green")
|
||||||
response = rag_agent.create_turn(
|
response = rag_agent.create_turn(
|
||||||
messages=[{"role": "user", "content": prompt}],
|
messages=[{"role": "user", "content": prompt}],
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
|
|
|
@ -51,6 +51,7 @@ This first example walks you through how to evaluate a model candidate served by
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import datasets
|
import datasets
|
||||||
|
|
||||||
ds = datasets.load_dataset(path="llamastack/mmmu", name="Agriculture", split="dev")
|
ds = datasets.load_dataset(path="llamastack/mmmu", name="Agriculture", split="dev")
|
||||||
ds = ds.select_columns(["chat_completion_input", "input_query", "expected_answer"])
|
ds = ds.select_columns(["chat_completion_input", "input_query", "expected_answer"])
|
||||||
eval_rows = ds.to_pandas().to_dict(orient="records")
|
eval_rows = ds.to_pandas().to_dict(orient="records")
|
||||||
|
@ -79,7 +80,7 @@ system_message = {
|
||||||
client.eval_tasks.register(
|
client.eval_tasks.register(
|
||||||
eval_task_id="meta-reference::mmmu",
|
eval_task_id="meta-reference::mmmu",
|
||||||
dataset_id=f"mmmu-{subset}-{split}",
|
dataset_id=f"mmmu-{subset}-{split}",
|
||||||
scoring_functions=["basic::regex_parser_multiple_choice_answer"]
|
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
|
||||||
)
|
)
|
||||||
|
|
||||||
response = client.eval.evaluate_rows(
|
response = client.eval.evaluate_rows(
|
||||||
|
@ -98,9 +99,9 @@ response = client.eval.evaluate_rows(
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
"repeat_penalty": 1.0,
|
"repeat_penalty": 1.0,
|
||||||
},
|
},
|
||||||
"system_message": system_message
|
"system_message": system_message,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -124,7 +125,7 @@ _ = client.datasets.register(
|
||||||
"input_query": {"type": "string"},
|
"input_query": {"type": "string"},
|
||||||
"expected_answer": {"type": "string"},
|
"expected_answer": {"type": "string"},
|
||||||
"chat_completion_input": {"type": "chat_completion_input"},
|
"chat_completion_input": {"type": "chat_completion_input"},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
eval_rows = client.datasetio.get_rows_paginated(
|
eval_rows = client.datasetio.get_rows_paginated(
|
||||||
|
@ -137,7 +138,7 @@ eval_rows = client.datasetio.get_rows_paginated(
|
||||||
client.eval_tasks.register(
|
client.eval_tasks.register(
|
||||||
eval_task_id="meta-reference::simpleqa",
|
eval_task_id="meta-reference::simpleqa",
|
||||||
dataset_id=simpleqa_dataset_id,
|
dataset_id=simpleqa_dataset_id,
|
||||||
scoring_functions=["llm-as-judge::405b-simpleqa"]
|
scoring_functions=["llm-as-judge::405b-simpleqa"],
|
||||||
)
|
)
|
||||||
|
|
||||||
response = client.eval.evaluate_rows(
|
response = client.eval.evaluate_rows(
|
||||||
|
@ -156,8 +157,8 @@ response = client.eval.evaluate_rows(
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
"repeat_penalty": 1.0,
|
"repeat_penalty": 1.0,
|
||||||
},
|
},
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -180,14 +181,14 @@ agent_config = {
|
||||||
{
|
{
|
||||||
"type": "brave_search",
|
"type": "brave_search",
|
||||||
"engine": "tavily",
|
"engine": "tavily",
|
||||||
"api_key": userdata.get("TAVILY_SEARCH_API_KEY")
|
"api_key": userdata.get("TAVILY_SEARCH_API_KEY"),
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"tool_choice": "auto",
|
"tool_choice": "auto",
|
||||||
"tool_prompt_format": "json",
|
"tool_prompt_format": "json",
|
||||||
"input_shields": [],
|
"input_shields": [],
|
||||||
"output_shields": [],
|
"output_shields": [],
|
||||||
"enable_session_persistence": False
|
"enable_session_persistence": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.eval.evaluate_rows(
|
response = client.eval.evaluate_rows(
|
||||||
|
@ -199,8 +200,8 @@ response = client.eval.evaluate_rows(
|
||||||
"eval_candidate": {
|
"eval_candidate": {
|
||||||
"type": "agent",
|
"type": "agent",
|
||||||
"config": agent_config,
|
"config": agent_config,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -237,7 +238,9 @@ GENERATED_RESPONSE: {generated_answer}
|
||||||
EXPECTED_RESPONSE: {expected_answer}
|
EXPECTED_RESPONSE: {expected_answer}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
input_query = "What are the top 5 topics that were explained? Only list succinct bullet points."
|
input_query = (
|
||||||
|
"What are the top 5 topics that were explained? Only list succinct bullet points."
|
||||||
|
)
|
||||||
generated_answer = """
|
generated_answer = """
|
||||||
Here are the top 5 topics that were explained in the documentation for Torchtune:
|
Here are the top 5 topics that were explained in the documentation for Torchtune:
|
||||||
|
|
||||||
|
@ -268,7 +271,9 @@ scoring_params = {
|
||||||
"braintrust::factuality": None,
|
"braintrust::factuality": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.scoring.score(input_rows=dataset_rows, scoring_functions=scoring_params)
|
response = client.scoring.score(
|
||||||
|
input_rows=dataset_rows, scoring_functions=scoring_params
|
||||||
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Running Evaluations via CLI
|
## Running Evaluations via CLI
|
||||||
|
|
|
@ -33,7 +33,11 @@ from llama_stack_client.types import (
|
||||||
Types:
|
Types:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from llama_stack_client.types import ListToolGroupsResponse, ToolGroup, ToolgroupListResponse
|
from llama_stack_client.types import (
|
||||||
|
ListToolGroupsResponse,
|
||||||
|
ToolGroup,
|
||||||
|
ToolgroupListResponse,
|
||||||
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
Methods:
|
Methods:
|
||||||
|
@ -444,7 +448,11 @@ Methods:
|
||||||
Types:
|
Types:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from llama_stack_client.types import EvalTask, ListEvalTasksResponse, EvalTaskListResponse
|
from llama_stack_client.types import (
|
||||||
|
EvalTask,
|
||||||
|
ListEvalTasksResponse,
|
||||||
|
EvalTaskListResponse,
|
||||||
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
Methods:
|
Methods:
|
||||||
|
|
|
@ -224,7 +224,7 @@ client = LlamaStackClient(base_url="http://localhost:5001")
|
||||||
response = client.inference.chat_completion(
|
response = client.inference.chat_completion(
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": "You are a friendly assistant."},
|
{"role": "system", "content": "You are a friendly assistant."},
|
||||||
{"role": "user", "content": "Write a two-sentence poem about llama."}
|
{"role": "user", "content": "Write a two-sentence poem about llama."},
|
||||||
],
|
],
|
||||||
model_id=INFERENCE_MODEL,
|
model_id=INFERENCE_MODEL,
|
||||||
)
|
)
|
||||||
|
|
|
@ -86,9 +86,7 @@ class ShieldCallStep(StepCommon):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class MemoryRetrievalStep(StepCommon):
|
class MemoryRetrievalStep(StepCommon):
|
||||||
step_type: Literal[StepType.memory_retrieval.value] = (
|
step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value
|
||||||
StepType.memory_retrieval.value
|
|
||||||
)
|
|
||||||
vector_db_ids: str
|
vector_db_ids: str
|
||||||
inserted_context: InterleavedContent
|
inserted_context: InterleavedContent
|
||||||
|
|
||||||
|
@ -184,9 +182,7 @@ class AgentTurnResponseEventType(Enum):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AgentTurnResponseStepStartPayload(BaseModel):
|
class AgentTurnResponseStepStartPayload(BaseModel):
|
||||||
event_type: Literal[AgentTurnResponseEventType.step_start.value] = (
|
event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value
|
||||||
AgentTurnResponseEventType.step_start.value
|
|
||||||
)
|
|
||||||
step_type: StepType
|
step_type: StepType
|
||||||
step_id: str
|
step_id: str
|
||||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||||
|
@ -194,9 +190,7 @@ class AgentTurnResponseStepStartPayload(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AgentTurnResponseStepCompletePayload(BaseModel):
|
class AgentTurnResponseStepCompletePayload(BaseModel):
|
||||||
event_type: Literal[AgentTurnResponseEventType.step_complete.value] = (
|
event_type: Literal[AgentTurnResponseEventType.step_complete.value] = AgentTurnResponseEventType.step_complete.value
|
||||||
AgentTurnResponseEventType.step_complete.value
|
|
||||||
)
|
|
||||||
step_type: StepType
|
step_type: StepType
|
||||||
step_id: str
|
step_id: str
|
||||||
step_details: Step
|
step_details: Step
|
||||||
|
@ -206,9 +200,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
|
||||||
class AgentTurnResponseStepProgressPayload(BaseModel):
|
class AgentTurnResponseStepProgressPayload(BaseModel):
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
event_type: Literal[AgentTurnResponseEventType.step_progress.value] = (
|
event_type: Literal[AgentTurnResponseEventType.step_progress.value] = AgentTurnResponseEventType.step_progress.value
|
||||||
AgentTurnResponseEventType.step_progress.value
|
|
||||||
)
|
|
||||||
step_type: StepType
|
step_type: StepType
|
||||||
step_id: str
|
step_id: str
|
||||||
|
|
||||||
|
@ -217,17 +209,13 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AgentTurnResponseTurnStartPayload(BaseModel):
|
class AgentTurnResponseTurnStartPayload(BaseModel):
|
||||||
event_type: Literal[AgentTurnResponseEventType.turn_start.value] = (
|
event_type: Literal[AgentTurnResponseEventType.turn_start.value] = AgentTurnResponseEventType.turn_start.value
|
||||||
AgentTurnResponseEventType.turn_start.value
|
|
||||||
)
|
|
||||||
turn_id: str
|
turn_id: str
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AgentTurnResponseTurnCompletePayload(BaseModel):
|
class AgentTurnResponseTurnCompletePayload(BaseModel):
|
||||||
event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = (
|
event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = AgentTurnResponseEventType.turn_complete.value
|
||||||
AgentTurnResponseEventType.turn_complete.value
|
|
||||||
)
|
|
||||||
turn: Turn
|
turn: Turn
|
||||||
|
|
||||||
|
|
||||||
|
@ -329,9 +317,7 @@ class Agents(Protocol):
|
||||||
toolgroups: Optional[List[AgentToolGroup]] = None,
|
toolgroups: Optional[List[AgentToolGroup]] = None,
|
||||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
||||||
|
|
||||||
@webmethod(
|
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET")
|
||||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET"
|
|
||||||
)
|
|
||||||
async def get_agents_turn(
|
async def get_agents_turn(
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
|
|
|
@ -63,9 +63,7 @@ class EventLogger:
|
||||||
if isinstance(chunk, ToolResponseMessage):
|
if isinstance(chunk, ToolResponseMessage):
|
||||||
yield (
|
yield (
|
||||||
chunk,
|
chunk,
|
||||||
LogEvent(
|
LogEvent(role="CustomTool", content=chunk.content, color="grey"),
|
||||||
role="CustomTool", content=chunk.content, color="grey"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -81,17 +79,12 @@ class EventLogger:
|
||||||
|
|
||||||
step_type = event.payload.step_type
|
step_type = event.payload.step_type
|
||||||
# handle safety
|
# handle safety
|
||||||
if (
|
if step_type == StepType.shield_call and event_type == EventType.step_complete.value:
|
||||||
step_type == StepType.shield_call
|
|
||||||
and event_type == EventType.step_complete.value
|
|
||||||
):
|
|
||||||
violation = event.payload.step_details.violation
|
violation = event.payload.step_details.violation
|
||||||
if not violation:
|
if not violation:
|
||||||
yield (
|
yield (
|
||||||
event,
|
event,
|
||||||
LogEvent(
|
LogEvent(role=step_type, content="No Violation", color="magenta"),
|
||||||
role=step_type, content="No Violation", color="magenta"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield (
|
yield (
|
||||||
|
@ -110,9 +103,7 @@ class EventLogger:
|
||||||
# TODO: Currently this event is never received
|
# TODO: Currently this event is never received
|
||||||
yield (
|
yield (
|
||||||
event,
|
event,
|
||||||
LogEvent(
|
LogEvent(role=step_type, content="", end="", color="yellow"),
|
||||||
role=step_type, content="", end="", color="yellow"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
elif event_type == EventType.step_progress.value:
|
elif event_type == EventType.step_progress.value:
|
||||||
# HACK: if previous was not step/event was not inference's step_progress
|
# HACK: if previous was not step/event was not inference's step_progress
|
||||||
|
@ -125,9 +116,7 @@ class EventLogger:
|
||||||
):
|
):
|
||||||
yield (
|
yield (
|
||||||
event,
|
event,
|
||||||
LogEvent(
|
LogEvent(role=step_type, content="", end="", color="yellow"),
|
||||||
role=step_type, content="", end="", color="yellow"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
delta = event.payload.delta
|
delta = event.payload.delta
|
||||||
|
@ -161,9 +150,7 @@ class EventLogger:
|
||||||
if event_type == EventType.step_complete.value:
|
if event_type == EventType.step_complete.value:
|
||||||
response = event.payload.step_details.model_response
|
response = event.payload.step_details.model_response
|
||||||
if response.tool_calls:
|
if response.tool_calls:
|
||||||
content = ToolUtils.encode_tool_call(
|
content = ToolUtils.encode_tool_call(response.tool_calls[0], tool_prompt_format)
|
||||||
response.tool_calls[0], tool_prompt_format
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
content = response.content
|
content = response.content
|
||||||
yield (
|
yield (
|
||||||
|
@ -202,10 +189,7 @@ class EventLogger:
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if step_type == StepType.memory_retrieval and event_type == EventType.step_complete.value:
|
||||||
step_type == StepType.memory_retrieval
|
|
||||||
and event_type == EventType.step_complete.value
|
|
||||||
):
|
|
||||||
details = event.payload.step_details
|
details = event.payload.step_details
|
||||||
inserted_context = interleaved_content_as_str(details.inserted_context)
|
inserted_context = interleaved_content_as_str(details.inserted_context)
|
||||||
content = f"fetched {len(inserted_context)} bytes from {details.vector_db_ids}"
|
content = f"fetched {len(inserted_context)} bytes from {details.vector_db_ids}"
|
||||||
|
|
|
@ -39,6 +39,4 @@ class DatasetIO(Protocol):
|
||||||
) -> PaginatedRowsResult: ...
|
) -> PaginatedRowsResult: ...
|
||||||
|
|
||||||
@webmethod(route="/datasetio/rows", method="POST")
|
@webmethod(route="/datasetio/rows", method="POST")
|
||||||
async def append_rows(
|
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
|
||||||
self, dataset_id: str, rows: List[Dict[str, Any]]
|
|
||||||
) -> None: ...
|
|
||||||
|
|
|
@ -63,9 +63,7 @@ class AppEvalTaskConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
EvalTaskConfig = register_schema(
|
EvalTaskConfig = register_schema(
|
||||||
Annotated[
|
Annotated[Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type")],
|
||||||
Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type")
|
|
||||||
],
|
|
||||||
name="EvalTaskConfig",
|
name="EvalTaskConfig",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -245,9 +245,7 @@ class JsonSchemaResponseFormat(BaseModel):
|
||||||
:param json_schema: The JSON schema the response should conform to. In a Python SDK, this is often a `pydantic` model.
|
:param json_schema: The JSON schema the response should conform to. In a Python SDK, this is often a `pydantic` model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: Literal[ResponseFormatType.json_schema.value] = (
|
type: Literal[ResponseFormatType.json_schema.value] = ResponseFormatType.json_schema.value
|
||||||
ResponseFormatType.json_schema.value
|
|
||||||
)
|
|
||||||
json_schema: Dict[str, Any]
|
json_schema: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
@ -406,9 +404,7 @@ class Inference(Protocol):
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: Optional[ResponseFormat] = None,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> Union[
|
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||||
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
|
|
||||||
]:
|
|
||||||
"""Generate a chat completion for the given messages using the specified model.
|
"""Generate a chat completion for the given messages using the specified model.
|
||||||
|
|
||||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||||
|
|
|
@ -89,9 +89,7 @@ class QATFinetuningConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
AlgorithmConfig = register_schema(
|
AlgorithmConfig = register_schema(
|
||||||
Annotated[
|
Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")],
|
||||||
Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
|
|
||||||
],
|
|
||||||
name="AlgorithmConfig",
|
name="AlgorithmConfig",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -204,14 +202,10 @@ class PostTraining(Protocol):
|
||||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
|
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/post-training/job/status", method="GET")
|
@webmethod(route="/post-training/job/status", method="GET")
|
||||||
async def get_training_job_status(
|
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: ...
|
||||||
self, job_uuid: str
|
|
||||||
) -> Optional[PostTrainingJobStatusResponse]: ...
|
|
||||||
|
|
||||||
@webmethod(route="/post-training/job/cancel", method="POST")
|
@webmethod(route="/post-training/job/cancel", method="POST")
|
||||||
async def cancel_training_job(self, job_uuid: str) -> None: ...
|
async def cancel_training_job(self, job_uuid: str) -> None: ...
|
||||||
|
|
||||||
@webmethod(route="/post-training/job/artifacts", method="GET")
|
@webmethod(route="/post-training/job/artifacts", method="GET")
|
||||||
async def get_training_job_artifacts(
|
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]: ...
|
||||||
self, job_uuid: str
|
|
||||||
) -> Optional[PostTrainingJobArtifactsResponse]: ...
|
|
||||||
|
|
|
@ -23,9 +23,7 @@ class ResourceType(Enum):
|
||||||
class Resource(BaseModel):
|
class Resource(BaseModel):
|
||||||
"""Base class for all Llama Stack resources"""
|
"""Base class for all Llama Stack resources"""
|
||||||
|
|
||||||
identifier: str = Field(
|
identifier: str = Field(description="Unique identifier for this resource in llama stack")
|
||||||
description="Unique identifier for this resource in llama stack"
|
|
||||||
)
|
|
||||||
|
|
||||||
provider_resource_id: str = Field(
|
provider_resource_id: str = Field(
|
||||||
description="Unique identifier for this resource in the provider",
|
description="Unique identifier for this resource in the provider",
|
||||||
|
@ -34,6 +32,4 @@ class Resource(BaseModel):
|
||||||
|
|
||||||
provider_id: str = Field(description="ID of the provider that owns this resource")
|
provider_id: str = Field(description="ID of the provider that owns this resource")
|
||||||
|
|
||||||
type: ResourceType = Field(
|
type: ResourceType = Field(description="Type of resource (e.g. 'model', 'shield', 'vector_db', etc.)")
|
||||||
description="Type of resource (e.g. 'model', 'shield', 'vector_db', etc.)"
|
|
||||||
)
|
|
||||||
|
|
|
@ -43,9 +43,7 @@ class AggregationFunctionType(Enum):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class LLMAsJudgeScoringFnParams(BaseModel):
|
class LLMAsJudgeScoringFnParams(BaseModel):
|
||||||
type: Literal[ScoringFnParamsType.llm_as_judge.value] = (
|
type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
|
||||||
ScoringFnParamsType.llm_as_judge.value
|
|
||||||
)
|
|
||||||
judge_model: str
|
judge_model: str
|
||||||
prompt_template: Optional[str] = None
|
prompt_template: Optional[str] = None
|
||||||
judge_score_regexes: Optional[List[str]] = Field(
|
judge_score_regexes: Optional[List[str]] = Field(
|
||||||
|
@ -60,9 +58,7 @@ class LLMAsJudgeScoringFnParams(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class RegexParserScoringFnParams(BaseModel):
|
class RegexParserScoringFnParams(BaseModel):
|
||||||
type: Literal[ScoringFnParamsType.regex_parser.value] = (
|
type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
|
||||||
ScoringFnParamsType.regex_parser.value
|
|
||||||
)
|
|
||||||
parsing_regexes: Optional[List[str]] = Field(
|
parsing_regexes: Optional[List[str]] = Field(
|
||||||
description="Regex to extract the answer from generated response",
|
description="Regex to extract the answer from generated response",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
|
@ -112,9 +108,7 @@ class CommonScoringFnFields(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ScoringFn(CommonScoringFnFields, Resource):
|
class ScoringFn(CommonScoringFnFields, Resource):
|
||||||
type: Literal[ResourceType.scoring_function.value] = (
|
type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value
|
||||||
ResourceType.scoring_function.value
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def scoring_fn_id(self) -> str:
|
def scoring_fn_id(self) -> str:
|
||||||
|
@ -141,9 +135,7 @@ class ScoringFunctions(Protocol):
|
||||||
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
|
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/scoring-functions/{scoring_fn_id}", method="GET")
|
@webmethod(route="/scoring-functions/{scoring_fn_id}", method="GET")
|
||||||
async def get_scoring_function(
|
async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ...
|
||||||
self, scoring_fn_id: str, /
|
|
||||||
) -> Optional[ScoringFn]: ...
|
|
||||||
|
|
||||||
@webmethod(route="/scoring-functions", method="POST")
|
@webmethod(route="/scoring-functions", method="POST")
|
||||||
async def register_scoring_function(
|
async def register_scoring_function(
|
||||||
|
|
|
@ -102,9 +102,7 @@ class StructuredLogType(Enum):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class SpanStartPayload(BaseModel):
|
class SpanStartPayload(BaseModel):
|
||||||
type: Literal[StructuredLogType.SPAN_START.value] = (
|
type: Literal[StructuredLogType.SPAN_START.value] = StructuredLogType.SPAN_START.value
|
||||||
StructuredLogType.SPAN_START.value
|
|
||||||
)
|
|
||||||
name: str
|
name: str
|
||||||
parent_span_id: Optional[str] = None
|
parent_span_id: Optional[str] = None
|
||||||
|
|
||||||
|
@ -190,9 +188,7 @@ class QuerySpanTreeResponse(BaseModel):
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class Telemetry(Protocol):
|
class Telemetry(Protocol):
|
||||||
@webmethod(route="/telemetry/events", method="POST")
|
@webmethod(route="/telemetry/events", method="POST")
|
||||||
async def log_event(
|
async def log_event(self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400) -> None: ...
|
||||||
self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400
|
|
||||||
) -> None: ...
|
|
||||||
|
|
||||||
@webmethod(route="/telemetry/traces", method="GET")
|
@webmethod(route="/telemetry/traces", method="GET")
|
||||||
async def query_traces(
|
async def query_traces(
|
||||||
|
|
|
@ -64,9 +64,7 @@ RAGQueryGeneratorConfig = register_schema(
|
||||||
class RAGQueryConfig(BaseModel):
|
class RAGQueryConfig(BaseModel):
|
||||||
# This config defines how a query is generated using the messages
|
# This config defines how a query is generated using the messages
|
||||||
# for memory bank retrieval.
|
# for memory bank retrieval.
|
||||||
query_generator_config: RAGQueryGeneratorConfig = Field(
|
query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
|
||||||
default=DefaultRAGQueryGeneratorConfig()
|
|
||||||
)
|
|
||||||
max_tokens_in_context: int = 4096
|
max_tokens_in_context: int = 4096
|
||||||
max_chunks: int = 5
|
max_chunks: int = 5
|
||||||
|
|
||||||
|
|
|
@ -150,8 +150,6 @@ class ToolRuntime(Protocol):
|
||||||
) -> List[ToolDef]: ...
|
) -> List[ToolDef]: ...
|
||||||
|
|
||||||
@webmethod(route="/tool-runtime/invoke", method="POST")
|
@webmethod(route="/tool-runtime/invoke", method="POST")
|
||||||
async def invoke_tool(
|
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
|
||||||
self, tool_name: str, kwargs: Dict[str, Any]
|
|
||||||
) -> ToolInvocationResult:
|
|
||||||
"""Run a tool with the given arguments"""
|
"""Run a tool with the given arguments"""
|
||||||
...
|
...
|
||||||
|
|
|
@ -147,9 +147,7 @@ class ParallelDownloader:
|
||||||
"follow_redirects": True,
|
"follow_redirects": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def retry_with_exponential_backoff(
|
async def retry_with_exponential_backoff(self, task: DownloadTask, func, *args, **kwargs):
|
||||||
self, task: DownloadTask, func, *args, **kwargs
|
|
||||||
):
|
|
||||||
last_exception = None
|
last_exception = None
|
||||||
for attempt in range(task.max_retries):
|
for attempt in range(task.max_retries):
|
||||||
try:
|
try:
|
||||||
|
@ -166,13 +164,9 @@ class ParallelDownloader:
|
||||||
continue
|
continue
|
||||||
raise last_exception
|
raise last_exception
|
||||||
|
|
||||||
async def get_file_info(
|
async def get_file_info(self, client: httpx.AsyncClient, task: DownloadTask) -> None:
|
||||||
self, client: httpx.AsyncClient, task: DownloadTask
|
|
||||||
) -> None:
|
|
||||||
async def _get_info():
|
async def _get_info():
|
||||||
response = await client.head(
|
response = await client.head(task.url, headers={"Accept-Encoding": "identity"}, **self.client_options)
|
||||||
task.url, headers={"Accept-Encoding": "identity"}, **self.client_options
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
@ -201,14 +195,10 @@ class ParallelDownloader:
|
||||||
return False
|
return False
|
||||||
return os.path.getsize(task.output_file) == task.total_size
|
return os.path.getsize(task.output_file) == task.total_size
|
||||||
|
|
||||||
async def download_chunk(
|
async def download_chunk(self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int) -> None:
|
||||||
self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int
|
|
||||||
) -> None:
|
|
||||||
async def _download_chunk():
|
async def _download_chunk():
|
||||||
headers = {"Range": f"bytes={start}-{end}"}
|
headers = {"Range": f"bytes={start}-{end}"}
|
||||||
async with client.stream(
|
async with client.stream("GET", task.url, headers=headers, **self.client_options) as response:
|
||||||
"GET", task.url, headers=headers, **self.client_options
|
|
||||||
) as response:
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
with open(task.output_file, "ab") as file:
|
with open(task.output_file, "ab") as file:
|
||||||
|
@ -225,8 +215,7 @@ class ParallelDownloader:
|
||||||
await self.retry_with_exponential_backoff(task, _download_chunk)
|
await self.retry_with_exponential_backoff(task, _download_chunk)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise DownloadError(
|
raise DownloadError(
|
||||||
f"Failed to download chunk {start}-{end} after "
|
f"Failed to download chunk {start}-{end} after {task.max_retries} attempts: {str(e)}"
|
||||||
f"{task.max_retries} attempts: {str(e)}"
|
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
async def prepare_download(self, task: DownloadTask) -> None:
|
async def prepare_download(self, task: DownloadTask) -> None:
|
||||||
|
@ -244,9 +233,7 @@ class ParallelDownloader:
|
||||||
# Check if file is already downloaded
|
# Check if file is already downloaded
|
||||||
if os.path.exists(task.output_file):
|
if os.path.exists(task.output_file):
|
||||||
if self.verify_file_integrity(task):
|
if self.verify_file_integrity(task):
|
||||||
self.console.print(
|
self.console.print(f"[green]Already downloaded {task.output_file}[/green]")
|
||||||
f"[green]Already downloaded {task.output_file}[/green]"
|
|
||||||
)
|
|
||||||
self.progress.update(task.task_id, completed=task.total_size)
|
self.progress.update(task.task_id, completed=task.total_size)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -259,9 +246,7 @@ class ParallelDownloader:
|
||||||
|
|
||||||
current_pos = task.downloaded_size
|
current_pos = task.downloaded_size
|
||||||
while current_pos < task.total_size:
|
while current_pos < task.total_size:
|
||||||
chunk_end = min(
|
chunk_end = min(current_pos + chunk_size - 1, task.total_size - 1)
|
||||||
current_pos + chunk_size - 1, task.total_size - 1
|
|
||||||
)
|
|
||||||
chunks.append((current_pos, chunk_end))
|
chunks.append((current_pos, chunk_end))
|
||||||
current_pos = chunk_end + 1
|
current_pos = chunk_end + 1
|
||||||
|
|
||||||
|
@ -273,18 +258,12 @@ class ParallelDownloader:
|
||||||
raise DownloadError(f"Download failed: {str(e)}") from e
|
raise DownloadError(f"Download failed: {str(e)}") from e
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.progress.update(
|
self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
|
||||||
task.task_id, description=f"[red]Failed: {task.output_file}[/red]"
|
raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e
|
||||||
)
|
|
||||||
raise DownloadError(
|
|
||||||
f"Download failed for {task.output_file}: {str(e)}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
|
def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
|
||||||
try:
|
try:
|
||||||
total_remaining_size = sum(
|
total_remaining_size = sum(task.total_size - task.downloaded_size for task in tasks)
|
||||||
task.total_size - task.downloaded_size for task in tasks
|
|
||||||
)
|
|
||||||
dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
|
dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
|
||||||
free_space = shutil.disk_usage(dir_path).free
|
free_space = shutil.disk_usage(dir_path).free
|
||||||
|
|
||||||
|
@@ -314,9 +293,7 @@ class ParallelDownloader:
with self.progress:
for task in tasks:
desc = f"Downloading {Path(task.output_file).name}"
-task.task_id = self.progress.add_task(
-    desc, total=task.total_size, completed=task.downloaded_size
-)
+task.task_id = self.progress.add_task(desc, total=task.total_size, completed=task.downloaded_size)

semaphore = asyncio.Semaphore(self.max_concurrent_downloads)

@@ -332,9 +309,7 @@ class ParallelDownloader:
if failed_tasks:
self.console.print("\n[red]Some downloads failed:[/red]")
for task, error in failed_tasks:
-self.console.print(
-    f"[red]- {Path(task.output_file).name}: {error}[/red]"
-)
+self.console.print(f"[red]- {Path(task.output_file).name}: {error}[/red]")
raise DownloadError(f"{len(failed_tasks)} downloads failed")

@@ -396,11 +371,7 @@ def _meta_download(
output_file = str(output_dir / f)
url = meta_url.replace("*", f"{info.folder}/{f}")
total_size = info.pth_size if "consolidated" in f else 0
-tasks.append(
-    DownloadTask(
-        url=url, output_file=output_file, total_size=total_size, max_retries=3
-    )
-)
+tasks.append(DownloadTask(url=url, output_file=output_file, total_size=total_size, max_retries=3))

# Initialize and run parallel downloader
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)

@@ -446,14 +417,10 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
os.makedirs(output_dir, exist_ok=True)

if any(output_dir.iterdir()):
-console.print(
-    f"[yellow]Output directory {output_dir} is not empty.[/yellow]"
-)
+console.print(f"[yellow]Output directory {output_dir} is not empty.[/yellow]")

while True:
-resp = input(
-    "Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
-)
+resp = input("Do you want to (C)ontinue download or (R)estart completely? (continue/restart): ")
if resp.lower() in ["restart", "r"]:
shutil.rmtree(output_dir)
os.makedirs(output_dir, exist_ok=True)

@@ -471,9 +438,7 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
]

# Initialize and run parallel downloader
-downloader = ParallelDownloader(
-    max_concurrent_downloads=max_concurrent_downloads
-)
+downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
asyncio.run(downloader.download_all(tasks))
@@ -47,33 +47,20 @@ class ModelPromptFormat(Subcommand):
# Only Llama 3.1 and 3.2 are supported
supported_model_ids = [
-    m
-    for m in CoreModelId
-    if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
+    m for m in CoreModelId if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
]
model_str = "\n".join([m.value for m in supported_model_ids])
try:
model_id = CoreModelId(args.model_name)
except ValueError:
-self.parser.error(
-    f"{args.model_name} is not a valid Model. Choose one from --\n{model_str}"
-)
+self.parser.error(f"{args.model_name} is not a valid Model. Choose one from --\n{model_str}")

if model_id not in supported_model_ids:
-self.parser.error(
-    f"{model_id} is not a valid Model. Choose one from --\n {model_str}"
-)
+self.parser.error(f"{model_id} is not a valid Model. Choose one from --\n {model_str}")

-llama_3_1_file = (
-    importlib.resources.files("llama_models") / "llama3_1/prompt_format.md"
-)
-llama_3_2_text_file = (
-    importlib.resources.files("llama_models") / "llama3_2/text_prompt_format.md"
-)
-llama_3_2_vision_file = (
-    importlib.resources.files("llama_models")
-    / "llama3_2/vision_prompt_format.md"
-)
+llama_3_1_file = importlib.resources.files("llama_models") / "llama3_1/prompt_format.md"
+llama_3_2_text_file = importlib.resources.files("llama_models") / "llama3_2/text_prompt_format.md"
+llama_3_2_vision_file = importlib.resources.files("llama_models") / "llama3_2/vision_prompt_format.md"
if model_family(model_id) == ModelFamily.llama3_1:
with importlib.resources.as_file(llama_3_1_file) as f:
content = f.open("r").read()
@@ -17,16 +17,12 @@ class PromptGuardModel(BaseModel):
"""Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""

model_id: str = "Prompt-Guard-86M"
-description: str = (
-    "Prompt Guard. NOTE: this model will not be provided via `llama` CLI soon."
-)
+description: str = "Prompt Guard. NOTE: this model will not be provided via `llama` CLI soon."
is_featured: bool = False
huggingface_repo: str = "meta-llama/Prompt-Guard-86M"
max_seq_length: int = 2048
is_instruct_model: bool = False
-quantization_format: CheckpointQuantizationFormat = (
-    CheckpointQuantizationFormat.bf16
-)
+quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
arch_args: Dict[str, Any] = Field(default_factory=dict)
recommended_sampling_params: Optional[SamplingParams] = None
@@ -56,9 +56,7 @@ def available_templates_specs() -> Dict[str, BuildConfig]:
return template_specs

-def run_stack_build_command(
-    parser: argparse.ArgumentParser, args: argparse.Namespace
-) -> None:
+def run_stack_build_command(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
if args.list_templates:
return _run_template_list_cmd()

@@ -129,11 +127,7 @@ def run_stack_build_command(
providers = dict()
for api, providers_for_api in get_provider_registry().items():
-available_providers = [
-    x
-    for x in providers_for_api.keys()
-    if x not in ("remote", "remote::sample")
-]
+available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")]
api_provider = prompt(
"> Enter provider for API {}: ".format(api.value),
completer=WordCompleter(available_providers),

@@ -156,9 +150,7 @@ def run_stack_build_command(
description=description,
)

-build_config = BuildConfig(
-    image_type=image_type, distribution_spec=distribution_spec
-)
+build_config = BuildConfig(image_type=image_type, distribution_spec=distribution_spec)
else:
with open(args.config, "r") as f:
try:

@@ -179,9 +171,7 @@ def run_stack_build_command(
if args.print_deps_only:
print(f"# Dependencies for {args.template or args.config or image_name}")
-normal_deps, special_deps = get_provider_dependencies(
-    build_config.distribution_spec.providers
-)
+normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
normal_deps += SERVER_DEPENDENCIES
print(f"uv pip install {' '.join(normal_deps)}")
for special_dep in special_deps:

@@ -206,9 +196,7 @@ def _generate_run_config(
"""
apis = list(build_config.distribution_spec.providers.keys())
run_config = StackRunConfig(
-container_image=(
-    image_name if build_config.image_type == ImageType.container.value else None
-),
+container_image=(image_name if build_config.image_type == ImageType.container.value else None),
image_name=image_name,
apis=apis,
providers={},

@@ -228,13 +216,9 @@ def _generate_run_config(
if p.deprecation_error:
raise InvalidProviderError(p.deprecation_error)

-config_type = instantiate_class_type(
-    provider_registry[Api(api)][provider_type].config_class
-)
+config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
if hasattr(config_type, "sample_run_config"):
-config = config_type.sample_run_config(
-    __distro_dir__=f"distributions/{image_name}"
-)
+config = config_type.sample_run_config(__distro_dir__=f"distributions/{image_name}")
else:
config = {}

@@ -269,9 +253,7 @@ def _run_stack_build_command_from_build_config(
image_name = f"distribution-{template_name}"
else:
if not image_name:
-raise ValueError(
-    "Please specify an image name when building a container image without a template"
-)
+raise ValueError("Please specify an image name when building a container image without a template")
elif build_config.image_type == ImageType.conda.value:
if not image_name:
raise ValueError("Please specify an image name when building a conda image")

@@ -299,10 +281,7 @@ def _run_stack_build_command_from_build_config(
if template_name:
# copy run.yaml from template to build_dir instead of generating it again
-template_path = (
-    importlib.resources.files("llama_stack")
-    / f"templates/{template_name}/run.yaml"
-)
+template_path = importlib.resources.files("llama_stack") / f"templates/{template_name}/run.yaml"
with importlib.resources.as_file(template_path) as path:
run_config_file = build_dir / f"{template_name}-run.yaml"
shutil.copy(path, run_config_file)
@@ -82,31 +82,21 @@ class StackRun(Subcommand):
if not config_file.exists() and not has_yaml_suffix:
# check if this is a template
-config_file = (
-    Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
-)
+config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
if config_file.exists():
template_name = args.config

if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to conda dir
-config_file = Path(
-    BUILDS_BASE_DIR / ImageType.conda.value / f"{args.config}-run.yaml"
-)
+config_file = Path(BUILDS_BASE_DIR / ImageType.conda.value / f"{args.config}-run.yaml")

if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to container dir
-config_file = Path(
-    BUILDS_BASE_DIR / ImageType.container.value / f"{args.config}-run.yaml"
-)
+config_file = Path(BUILDS_BASE_DIR / ImageType.container.value / f"{args.config}-run.yaml")

if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
-config_file = Path(
-    DISTRIBS_BASE_DIR
-    / f"llamastack-{args.config}"
-    / f"{args.config}-run.yaml"
-)
+config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")

if not config_file.exists():
self.parser.error(
@@ -119,15 +109,8 @@ class StackRun(Subcommand):
config = parse_and_maybe_upgrade_config(config_dict)

if config.container_image:
-script = (
-    importlib.resources.files("llama_stack")
-    / "distribution/start_container.sh"
-)
-image_name = (
-    f"distribution-{template_name}"
-    if template_name
-    else config.container_image
-)
+script = importlib.resources.files("llama_stack") / "distribution/start_container.sh"
+image_name = f"distribution-{template_name}" if template_name else config.container_image
run_args = [script, image_name]
else:
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")

@@ -145,11 +128,7 @@ class StackRun(Subcommand):
if env_name == "base":
return os.environ.get("CONDA_PREFIX")
# Get conda environments info
-conda_env_info = json.loads(
-    subprocess.check_output(
-        ["conda", "info", "--envs", "--json"]
-    ).decode()
-)
+conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
envs = conda_env_info["envs"]
for envpath in envs:
if envpath.endswith(env_name):
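The reflowed conda lookup shells out to conda itself and scans the returned environment prefixes; a minimal sketch of the same idea, runnable where the conda CLI is on PATH (the function name and the None fallback are assumptions for illustration):

import json
import subprocess

def find_conda_env_prefix(env_name: str) -> str | None:
    """Return the filesystem prefix of a named conda environment, if one exists."""
    raw = subprocess.check_output(["conda", "info", "--envs", "--json"]).decode()
    for envpath in json.loads(raw)["envs"]:
        if envpath.endswith(env_name):
            return envpath
    return None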
@@ -173,10 +152,7 @@ class StackRun(Subcommand):
)
return

-script = (
-    importlib.resources.files("llama_stack")
-    / "distribution/start_conda_env.sh"
-)
+script = importlib.resources.files("llama_stack") / "distribution/start_conda_env.sh"
run_args = [
script,
image_name,
@@ -22,11 +22,7 @@ def format_row(row, col_widths):
if line.strip() == "":
lines.append("")
else:
-lines.extend(
-    textwrap.wrap(
-        line, width, break_long_words=False, replace_whitespace=False
-    )
-)
+lines.extend(textwrap.wrap(line, width, break_long_words=False, replace_whitespace=False))
return lines

wrapped = [wrap(item, width) for item, width in zip(row, col_widths)]
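The call being reflowed here wraps each table cell without splitting long tokens or collapsing whitespace; a small usage sketch of the same textwrap options (the sample string and width are illustrative):

import textwrap

line = "meta-llama/Llama-3.2-3B-Instruct is a fairly long identifier"
print(textwrap.wrap(line, 30, break_long_words=False, replace_whitespace=False))
# Long tokens such as the model id stay intact even when they exceed the column width.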
@@ -41,9 +41,7 @@ def up_to_date_config():
- provider_id: provider1
provider_type: inline::meta-reference
config: {{}}
-""".format(
-    version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat()
-)
+""".format(version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat())
)

@@ -83,9 +81,7 @@ def old_config():
telemetry:
provider_type: noop
config: {{}}
-""".format(
-    built_at=datetime.now().isoformat()
-)
+""".format(built_at=datetime.now().isoformat())
)

@@ -108,10 +104,7 @@ def test_parse_and_maybe_upgrade_config_up_to_date(up_to_date_config):
def test_parse_and_maybe_upgrade_config_old_format(old_config):
result = parse_and_maybe_upgrade_config(old_config)
assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
-assert all(
-    api in result.providers
-    for api in ["inference", "safety", "memory", "telemetry"]
-)
+assert all(api in result.providers for api in ["inference", "safety", "memory", "telemetry"])
safety_provider = result.providers["safety"][0]
assert safety_provider.provider_type == "meta-reference"
assert "llama_guard_shield" in safety_provider.config
@@ -72,9 +72,7 @@ def load_checksums(checklist_path: Path) -> Dict[str, str]:
return checksums

-def verify_files(
-    model_dir: Path, checksums: Dict[str, str], console: Console
-) -> List[VerificationResult]:
+def verify_files(model_dir: Path, checksums: Dict[str, str], console: Console) -> List[VerificationResult]:
results = []

with Progress(
@@ -58,22 +58,14 @@ def get_provider_dependencies(
for api_str, provider_or_providers in config_providers.items():
providers_for_api = all_providers[Api(api_str)]

-providers = (
-    provider_or_providers
-    if isinstance(provider_or_providers, list)
-    else [provider_or_providers]
-)
+providers = provider_or_providers if isinstance(provider_or_providers, list) else [provider_or_providers]

for provider in providers:
# Providers from BuildConfig and RunConfig are subtly different – not great
-provider_type = (
-    provider if isinstance(provider, str) else provider.provider_type
-)
+provider_type = provider if isinstance(provider, str) else provider.provider_type

if provider_type not in providers_for_api:
-raise ValueError(
-    f"Provider `{provider}` is not available for API `{api_str}`"
-)
+raise ValueError(f"Provider `{provider}` is not available for API `{api_str}`")

provider_spec = providers_for_api[provider_type]
deps.extend(provider_spec.pip_packages)
@@ -109,19 +101,13 @@ def build_image(
image_name: str,
template_or_config: str,
):
-container_base = (
-    build_config.distribution_spec.container_image or "python:3.10-slim"
-)
+container_base = build_config.distribution_spec.container_image or "python:3.10-slim"

-normal_deps, special_deps = get_provider_dependencies(
-    build_config.distribution_spec.providers
-)
+normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
normal_deps += SERVER_DEPENDENCIES

if build_config.image_type == ImageType.container.value:
-script = str(
-    importlib.resources.files("llama_stack") / "distribution/build_container.sh"
-)
+script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")
args = [
script,
template_or_config,

@@ -132,9 +118,7 @@ def build_image(
" ".join(normal_deps),
]
elif build_config.image_type == ImageType.conda.value:
-script = str(
-    importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh"
-)
+script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
args = [
script,
str(image_name),

@@ -142,9 +126,7 @@ def build_image(
" ".join(normal_deps),
]
elif build_config.image_type == ImageType.venv.value:
-script = str(
-    importlib.resources.files("llama_stack") / "distribution/build_venv.sh"
-)
+script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh")
args = [
script,
str(image_name),
@@ -68,9 +68,7 @@ def create_api_client_class(protocol) -> Type:
return_type = None
else:
return_type = extract_non_async_iterator_type(sig.return_annotation)
-assert return_type, (
-    f"Could not extract return type for {sig.return_annotation}"
-)
+assert return_type, f"Could not extract return type for {sig.return_annotation}"

async with httpx.AsyncClient() as client:
params = self.httpx_request_params(method_name, *args, **kwargs)

@@ -87,9 +85,7 @@ def create_api_client_class(protocol) -> Type:
webmethod, sig = self.routes[method_name]

return_type = extract_async_iterator_type(sig.return_annotation)
-assert return_type, (
-    f"Could not extract return type for {sig.return_annotation}"
-)
+assert return_type, f"Could not extract return type for {sig.return_annotation}"

async with httpx.AsyncClient() as client:
params = self.httpx_request_params(method_name, *args, **kwargs)

@@ -204,9 +200,7 @@ async def example(model: str = None):
if not model:
model = "Llama3.2-3B-Instruct"

-message = UserMessage(
-    content="hello world, write me a 2 sentence poem about the moon"
-)
+message = UserMessage(content="hello world, write me a 2 sentence poem about the moon")
cprint(f"User>{message.content}", "green")

stream = True
@@ -26,9 +26,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
logger = logging.getLogger(__name__)

-def configure_single_provider(
-    registry: Dict[str, ProviderSpec], provider: Provider
-) -> Provider:
+def configure_single_provider(registry: Dict[str, ProviderSpec], provider: Provider) -> Provider:
provider_spec = registry[provider.provider_type]
config_type = instantiate_class_type(provider_spec.config_class)
try:

@@ -47,9 +45,7 @@ def configure_single_provider(
)

-def configure_api_providers(
-    config: StackRunConfig, build_spec: DistributionSpec
-) -> StackRunConfig:
+def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec) -> StackRunConfig:
is_nux = len(config.providers) == 0

if is_nux:

@@ -87,9 +83,7 @@ def configure_api_providers(
updated_providers = []
for p in existing_providers:
logger.info(f"> Configuring provider `({p.provider_type})`")
-updated_providers.append(
-    configure_single_provider(provider_registry[api], p)
-)
+updated_providers.append(configure_single_provider(provider_registry[api], p))
logger.info("")
else:
# we are newly configuring this API

@@ -114,11 +108,7 @@ def configure_api_providers(
configure_single_provider(
provider_registry[api],
Provider(
-provider_id=(
-    f"{provider_type}-{i:02d}"
-    if len(plist) > 1
-    else provider_type
-),
+provider_id=(f"{provider_type}-{i:02d}" if len(plist) > 1 else provider_type),
provider_type=provider_type,
config={},
),

@@ -137,11 +127,7 @@ def upgrade_from_routing_table(
def get_providers(entries):
return [
Provider(
-provider_id=(
-    f"{entry['provider_type']}-{i:02d}"
-    if len(entries) > 1
-    else entry["provider_type"]
-),
+provider_id=(f"{entry['provider_type']}-{i:02d}" if len(entries) > 1 else entry["provider_type"]),
provider_type=entry["provider_type"],
config=entry["config"],
)
@@ -163,9 +163,7 @@ a default SQLite store will be used.""",
class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION

-distribution_spec: DistributionSpec = Field(
-    description="The distribution spec to build including API providers. "
-)
+distribution_spec: DistributionSpec = Field(description="The distribution spec to build including API providers. ")
image_type: str = Field(
default="conda",
description="Type of package to build (conda | container | venv)",
@@ -55,9 +55,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:

def providable_apis() -> List[Api]:
-routing_table_apis = set(
-    x.routing_table_api for x in builtin_automatically_routed_apis()
-)
+routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
@@ -154,9 +154,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
def sync_generator():
try:
-async_stream = loop.run_until_complete(
-    self.async_client.request(*args, **kwargs)
-)
+async_stream = loop.run_until_complete(self.async_client.request(*args, **kwargs))
while True:
chunk = loop.run_until_complete(async_stream.__anext__())
yield chunk

@@ -181,9 +179,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
# when using the library client, we should not log to console since many
# of our logs are intended for server-side usage
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
-os.environ["TELEMETRY_SINKS"] = ",".join(
-    sink for sink in current_sinks if sink != "console"
-)
+os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")

if config_path_or_template_name.endswith(".yaml"):
config_path = Path(config_path_or_template_name)

@@ -202,9 +198,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
async def initialize(self):
try:
-self.impls = await construct_stack(
-    self.config, self.custom_provider_registry
-)
+self.impls = await construct_stack(self.config, self.custom_provider_registry)
except ModuleNotFoundError as _e:
cprint(_e.msg, "red")
cprint(

@@ -247,9 +241,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
func = getattr(impl, endpoint.name)
if endpoint.method not in endpoint_impls:
endpoint_impls[endpoint.method] = {}
-endpoint_impls[endpoint.method][
-    _convert_path_to_regex(endpoint.route)
-] = func
+endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = func

self.endpoint_impls = endpoint_impls
return True

@@ -266,9 +258,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
raise ValueError("Client not initialized")

if self.provider_data:
-set_request_provider_data(
-    {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
-)
+set_request_provider_data({"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)})

if stream:
response = await self._call_streaming(

@@ -408,9 +398,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
)
return await response.parse()

-def _convert_body(
-    self, path: str, method: str, body: Optional[dict] = None
-) -> dict:
+def _convert_body(self, path: str, method: str, body: Optional[dict] = None) -> dict:
if not body:
return {}

@@ -425,7 +413,5 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
for param_name, param in sig.parameters.items():
if param_name in body:
value = body.get(param_name)
-converted_body[param_name] = convert_to_pydantic(
-    param.annotation, value
-)
+converted_body[param_name] = convert_to_pydantic(param.annotation, value)
return converted_body
@@ -115,9 +115,7 @@ async def resolve_impls(
- flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation
"""
-routing_table_apis = set(
-    x.routing_table_api for x in builtin_automatically_routed_apis()
-)
+routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
router_apis = set(x.router_api for x in builtin_automatically_routed_apis())

providers_with_specs = {}

@@ -125,16 +123,12 @@ async def resolve_impls(
for api_str, providers in run_config.providers.items():
api = Api(api_str)
if api in routing_table_apis:
-raise ValueError(
-    f"Provider for `{api_str}` is automatically provided and cannot be overridden"
-)
+raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")

specs = {}
for provider in providers:
if provider.provider_type not in provider_registry[api]:
-raise ValueError(
-    f"Provider `{provider.provider_type}` is not available for API `{api}`"
-)
+raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")

p = provider_registry[api][provider.provider_type]
if p.deprecation_error:

@@ -145,9 +139,7 @@ async def resolve_impls(
log.warning(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
)
-p.deps__ = [a.value for a in p.api_dependencies] + [
-    a.value for a in p.optional_api_dependencies
-]
+p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
spec = ProviderWithSpec(
spec=p,
**(provider.model_dump()),

@@ -158,9 +150,7 @@ async def resolve_impls(
providers_with_specs[key] = specs

apis_to_serve = run_config.apis or set(
-    list(providers_with_specs.keys())
-    + [x.value for x in routing_table_apis]
-    + [x.value for x in router_apis]
+    list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
)

for info in builtin_automatically_routed_apis():

@@ -197,9 +187,7 @@ async def resolve_impls(
)
}

-sorted_providers = topological_sort(
-    {k: v.values() for k, v in providers_with_specs.items()}
-)
+sorted_providers = topological_sort({k: v.values() for k, v in providers_with_specs.items()})
apis = [x[1].spec.api for x in sorted_providers]
sorted_providers.append(
(

@@ -237,9 +225,7 @@ async def resolve_impls(
inner_impls = {}
if isinstance(provider.spec, RoutingTableProviderSpec):
-inner_impls = inner_impls_by_provider_id[
-    f"inner-{provider.spec.router_api.value}"
-]
+inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]

impl = await instantiate_provider(
provider,
@@ -336,10 +322,7 @@ async def instantiate_provider(
# TODO: check compliance for special tool groups
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
check_protocol_compliance(impl, protocols[provider_spec.api])
-if (
-    not isinstance(provider_spec, AutoRoutedProviderSpec)
-    and provider_spec.api in additional_protocols
-):
+if not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols:
additional_api, _, _ = additional_protocols[provider_spec.api]
check_protocol_compliance(impl, additional_api)
@@ -367,19 +350,12 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
obj_params = set(obj_sig.parameters)
obj_params.discard("self")
if not (proto_params <= obj_params):
-log.error(
-    f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
-)
+log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method is actually implemented in the class
-method_owner = next(
-    (cls for cls in mro if name in cls.__dict__), None
-)
-if (
-    method_owner is None
-    or method_owner.__name__ == protocol.__name__
-):
+method_owner = next((cls for cls in mro if name in cls.__dict__), None)
+if method_owner is None or method_owner.__name__ == protocol.__name__:
missing_methods.append((name, "not_actually_implemented"))

if missing_methods:
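The method-owner check reflowed above walks a class's MRO to decide whether a protocol method was really overridden; a small self-contained sketch of the same idiom (the Protocol/Impl names are illustrative, not from the codebase):

class Protocol:
    def run(self):  # default that implementations are expected to override
        raise NotImplementedError

class Impl(Protocol):
    pass  # deliberately does not override run()

name = "run"
mro = type(Impl()).__mro__
method_owner = next((cls for cls in mro if name in cls.__dict__), None)
# The owner is Protocol rather than Impl, so the method counts as "not actually implemented".
print(method_owner.__name__ == Protocol.__name__)  # True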
@@ -85,9 +85,7 @@ class VectorIORouter(VectorIO):
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
) -> None:
-return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
-    vector_db_id, chunks, ttl_seconds
-)
+return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)

async def query_chunks(
self,

@@ -95,9 +93,7 @@ class VectorIORouter(VectorIO):
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryChunksResponse:
-return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
-    vector_db_id, query, params
-)
+return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)

class InferenceRouter(Inference):
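All of the router hunks here follow the same pattern: the router looks up a provider implementation by routing key and forwards the call unchanged. A minimal sketch of that delegation, with made-up RoutingTable and provider classes purely for illustration:

import asyncio

class FakeProvider:
    async def query_chunks(self, vector_db_id, query, params=None):
        return f"results for {query!r} from {vector_db_id}"

class RoutingTable:
    def __init__(self, impls):
        self._impls = impls  # routing key -> provider implementation

    def get_provider_impl(self, routing_key):
        return self._impls[routing_key]

class VectorIORouter:
    def __init__(self, routing_table):
        self.routing_table = routing_table

    async def query_chunks(self, vector_db_id, query, params=None):
        # Forward to whichever provider serves this vector DB.
        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)

router = VectorIORouter(RoutingTable({"my_docs": FakeProvider()}))
print(asyncio.run(router.query_chunks("my_docs", "what is llama stack?")))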
@@ -123,9 +119,7 @@ class InferenceRouter(Inference):
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> None:
-await self.routing_table.register_model(
-    model_id, provider_model_id, provider_id, metadata, model_type
-)
+await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)

async def chat_completion(
self,

@@ -143,9 +137,7 @@ class InferenceRouter(Inference):
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding:
-raise ValueError(
-    f"Model '{model_id}' is an embedding model and does not support chat completions"
-)
+raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
params = dict(
model_id=model_id,
messages=messages,

@@ -176,9 +168,7 @@ class InferenceRouter(Inference):
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding:
-raise ValueError(
-    f"Model '{model_id}' is an embedding model and does not support chat completions"
-)
+raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
provider = self.routing_table.get_provider_impl(model_id)
params = dict(
model_id=model_id,

@@ -202,9 +192,7 @@ class InferenceRouter(Inference):
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.llm:
-raise ValueError(
-    f"Model '{model_id}' is an LLM model and does not support embeddings"
-)
+raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
return await self.routing_table.get_provider_impl(model_id).embeddings(
model_id=model_id,
contents=contents,

@@ -231,9 +219,7 @@ class SafetyRouter(Safety):
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
) -> Shield:
-return await self.routing_table.register_shield(
-    shield_id, provider_shield_id, provider_id, params
-)
+return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)

async def run_shield(
self,

@@ -268,9 +254,7 @@ class DatasetIORouter(DatasetIO):
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
-return await self.routing_table.get_provider_impl(
-    dataset_id
-).get_rows_paginated(
+return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=rows_in_page,
page_token=page_token,

@@ -305,9 +289,7 @@ class ScoringRouter(Scoring):
) -> ScoreBatchResponse:
res = {}
for fn_identifier in scoring_functions.keys():
-score_response = await self.routing_table.get_provider_impl(
-    fn_identifier
-).score_batch(
+score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
dataset_id=dataset_id,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)

@@ -328,9 +310,7 @@ class ScoringRouter(Scoring):
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
-score_response = await self.routing_table.get_provider_impl(
-    fn_identifier
-).score(
+score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
input_rows=input_rows,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)

@@ -381,9 +361,7 @@ class EvalRouter(Eval):
task_id: str,
job_id: str,
) -> Optional[JobStatus]:
-return await self.routing_table.get_provider_impl(task_id).job_status(
-    task_id, job_id
-)
+return await self.routing_table.get_provider_impl(task_id).job_status(task_id, job_id)

async def job_cancel(
self,

@@ -420,9 +398,9 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult:
-return await self.routing_table.get_provider_impl(
-    "query_from_memory"
-).query(content, vector_db_ids, query_config)
+return await self.routing_table.get_provider_impl("query_from_memory").query(
+    content, vector_db_ids, query_config
+)

async def insert(
self,

@@ -430,9 +408,9 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
-return await self.routing_table.get_provider_impl(
-    "insert_into_memory"
-).insert(documents, vector_db_id, chunk_size_in_tokens)
+return await self.routing_table.get_provider_impl("insert_into_memory").insert(
+    documents, vector_db_id, chunk_size_in_tokens
+)

def __init__(
self,

@@ -460,6 +438,4 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
-return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
-    tool_group_id, mcp_endpoint
-)
+return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
@@ -94,9 +94,7 @@ class CommonRoutingTableImpl(RoutingTable):
self.dist_registry = dist_registry

async def initialize(self) -> None:
-async def add_objects(
-    objs: List[RoutableObjectWithProvider], provider_id: str, cls
-) -> None:
+async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
for obj in objs:
if cls is None:
obj.provider_id = provider_id

@@ -131,9 +129,7 @@ class CommonRoutingTableImpl(RoutingTable):
for p in self.impls_by_provider_id.values():
await p.shutdown()

-def get_provider_impl(
-    self, routing_key: str, provider_id: Optional[str] = None
-) -> Any:
+def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
def apiname_object():
if isinstance(self, ModelsRoutingTable):
return ("Inference", "model")

@@ -171,9 +167,7 @@ class CommonRoutingTableImpl(RoutingTable):
raise ValueError(f"Provider not found for `{routing_key}`")

-async def get_object_by_identifier(
-    self, type: str, identifier: str
-) -> Optional[RoutableObjectWithProvider]:
+async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
# Get from disk registry
obj = await self.dist_registry.get(type, identifier)
if not obj:

@@ -183,13 +177,9 @@ class CommonRoutingTableImpl(RoutingTable):
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
await self.dist_registry.delete(obj.type, obj.identifier)
-await unregister_object_from_provider(
-    obj, self.impls_by_provider_id[obj.provider_id]
-)
+await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])

-async def register_object(
-    self, obj: RoutableObjectWithProvider
-) -> RoutableObjectWithProvider:
+async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
# if provider_id is not specified, pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
obj.provider_id = list(self.impls_by_provider_id.keys())[0]

@@ -244,9 +234,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
if model_type is None:
model_type = ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
-raise ValueError(
-    "Embedding model must have an embedding dimension in its metadata"
-)
+raise ValueError("Embedding model must have an embedding dimension in its metadata")
model = Model(
identifier=model_id,
provider_resource_id=provider_model_id,

@@ -266,9 +254,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):

class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> ListShieldsResponse:
-return ListShieldsResponse(
-    data=await self.get_all_with_type(ResourceType.shield.value)
-)
+return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))

async def get_shield(self, identifier: str) -> Optional[Shield]:
return await self.get_object_by_identifier("shield", identifier)

@@ -340,9 +326,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
if model.model_type != ModelType.embedding:
raise ValueError(f"Model {embedding_model} is not an embedding model")
if "embedding_dimension" not in model.metadata:
-raise ValueError(
-    f"Model {embedding_model} does not have an embedding dimension"
-)
+raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
vector_db_data = {
"identifier": vector_db_id,
"type": ResourceType.vector_db.value,

@@ -364,9 +348,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):

class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse:
-return ListDatasetsResponse(
-    data=await self.get_all_with_type(ResourceType.dataset.value)
-)
+return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))

async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
return await self.get_object_by_identifier("dataset", dataset_id)

@@ -411,9 +393,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):

class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
-return ListScoringFunctionsResponse(
-    data=await self.get_all_with_type(ResourceType.scoring_function.value)
-)
+return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))

async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)

@@ -510,12 +490,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
args: Optional[Dict[str, Any]] = None,
) -> None:
tools = []
-tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
-    toolgroup_id, mcp_endpoint
-)
-tool_host = (
-    ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
-)
+tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
+tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution

for tool_def in tool_defs:
tools.append(
@@ -43,9 +43,7 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
         if api == Api.tool_runtime:
             for tool_group in SpecialToolGroup:
                 sub_protocol = toolgroup_protocols[tool_group]
-                sub_protocol_methods = inspect.getmembers(
-                    sub_protocol, predicate=inspect.isfunction
-                )
+                sub_protocol_methods = inspect.getmembers(sub_protocol, predicate=inspect.isfunction)
                 for name, method in sub_protocol_methods:
                     if not hasattr(method, "__webmethod__"):
                         continue

@@ -76,9 +76,7 @@ async def global_exception_handler(request: Request, exc: Exception):
     traceback.print_exception(exc)
     http_exc = translate_exception(exc)

-    return JSONResponse(
-        status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}}
-    )
+    return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}})


 def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]:

@@ -178,9 +176,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
         is_streaming = is_streaming_request(func.__name__, request, **kwargs)
         try:
             if is_streaming:
-                return StreamingResponse(
-                    sse_generator(func(**kwargs)), media_type="text/event-stream"
-                )
+                return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
             else:
                 value = func(**kwargs)
                 return await maybe_await(value)

@@ -190,11 +186,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):

     sig = inspect.signature(func)

-    new_params = [
-        inspect.Parameter(
-            "request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request
-        )
-    ]
+    new_params = [inspect.Parameter("request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request)]
     new_params.extend(sig.parameters.values())

     path_params = extract_path_params(route)

@@ -202,15 +194,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
     # Annotate parameters that are in the path with Path(...) and others with Body(...)
     new_params = [new_params[0]] + [
         (
-            param.replace(
-                annotation=Annotated[
-                    param.annotation, FastapiPath(..., title=param.name)
-                ]
-            )
+            param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
             if param.name in path_params
-            else param.replace(
-                annotation=Annotated[param.annotation, Body(..., embed=True)]
-            )
+            else param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
         )
         for param in new_params[1:]
     ]
@@ -244,12 +230,8 @@ class ClientVersionMiddleware:
             client_version = headers.get(b"x-llamastack-client-version", b"").decode()
             if client_version:
                 try:
-                    client_version_parts = tuple(
-                        map(int, client_version.split(".")[:2])
-                    )
-                    server_version_parts = tuple(
-                        map(int, self.server_version.split(".")[:2])
-                    )
+                    client_version_parts = tuple(map(int, client_version.split(".")[:2]))
+                    server_version_parts = tuple(map(int, self.server_version.split(".")[:2]))
                     if client_version_parts != server_version_parts:

                         async def send_version_error(send):

@@ -267,9 +249,7 @@ class ClientVersionMiddleware:
                                 }
                             }
                         ).encode()
-                        await send(
-                            {"type": "http.response.body", "body": error_msg}
-                        )
+                        await send({"type": "http.response.body", "body": error_msg})

                     return await send_version_error(send)
             except (ValueError, IndexError):

@@ -296,9 +276,7 @@ def main():
         default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
         help="Port to listen on",
     )
-    parser.add_argument(
-        "--disable-ipv6", action="store_true", help="Whether to disable IPv6 support"
-    )
+    parser.add_argument("--disable-ipv6", action="store_true", help="Whether to disable IPv6 support")
     parser.add_argument(
         "--env",
         action="append",

@@ -323,9 +301,7 @@ def main():
             raise ValueError(f"Config file {config_file} does not exist")
         print(f"Using config file: {config_file}")
     elif args.template:
-        config_file = (
-            Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
-        )
+        config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
         if not config_file.exists():
             raise ValueError(f"Template {args.template} does not exist")
         print(f"Using template {args.template} config file: {config_file}")

@@ -383,9 +359,7 @@ def main():
         impl_method = getattr(impl, endpoint.name)

         with warnings.catch_warnings():
-            warnings.filterwarnings(
-                "ignore", category=UserWarning, module="pydantic._internal._fields"
-            )
+            warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
             getattr(app, endpoint.method)(endpoint.route, response_model=None)(
                 create_dynamic_typed_route(
                     impl_method,

@@ -416,9 +390,7 @@ def main():

 def extract_path_params(route: str) -> List[str]:
     segments = route.split("/")
-    params = [
-        seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")
-    ]
+    params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]
     return params
@@ -110,9 +110,7 @@ class EnvVarError(Exception):
     def __init__(self, var_name: str, path: str = ""):
         self.var_name = var_name
         self.path = path
-        super().__init__(
-            f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}"
-        )
+        super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}")


 def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:

@@ -187,9 +185,7 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
         if not key:
             raise ValueError(f"Empty key in environment variable pair: {env_pair}")
         if not all(c.isalnum() or c == "_" for c in key):
-            raise ValueError(
-                f"Key must contain only alphanumeric characters and underscores: {key}"
-            )
+            raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}")
         return key, value
     except ValueError as e:
         raise ValueError(

@@ -202,20 +198,14 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
 async def construct_stack(
     run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
 ) -> Dict[Api, Any]:
-    dist_registry, _ = await create_dist_registry(
-        run_config.metadata_store, run_config.image_name
-    )
-    impls = await resolve_impls(
-        run_config, provider_registry or get_provider_registry(), dist_registry
-    )
+    dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
+    impls = await resolve_impls(run_config, provider_registry or get_provider_registry(), dist_registry)
     await register_resources(run_config, impls)
     return impls


 def get_stack_run_config_from_template(template: str) -> StackRunConfig:
-    template_path = (
-        importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
-    )
+    template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"

     with importlib.resources.as_file(template_path) as path:
         if not path.exists():
@@ -25,9 +25,7 @@ class DistributionRegistry(Protocol):

     def get_cached(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ...

-    async def update(
-        self, obj: RoutableObjectWithProvider
-    ) -> RoutableObjectWithProvider: ...
+    async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: ...

     async def register(self, obj: RoutableObjectWithProvider) -> bool: ...

@@ -61,9 +59,7 @@ class DiskDistributionRegistry(DistributionRegistry):
     async def initialize(self) -> None:
         pass

-    def get_cached(
-        self, type: str, identifier: str
-    ) -> Optional[RoutableObjectWithProvider]:
+    def get_cached(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
         # Disk registry does not have a cache
         raise NotImplementedError("Disk registry does not have a cache")

@@ -72,12 +68,8 @@ class DiskDistributionRegistry(DistributionRegistry):
         values = await self.kvstore.range(start_key, end_key)
         return _parse_registry_values(values)

-    async def get(
-        self, type: str, identifier: str
-    ) -> Optional[RoutableObjectWithProvider]:
-        json_str = await self.kvstore.get(
-            KEY_FORMAT.format(type=type, identifier=identifier)
-        )
+    async def get(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
+        json_str = await self.kvstore.get(KEY_FORMAT.format(type=type, identifier=identifier))
         if not json_str:
             return None

@@ -143,9 +135,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
     async def initialize(self) -> None:
         await self._ensure_initialized()

-    def get_cached(
-        self, type: str, identifier: str
-    ) -> Optional[RoutableObjectWithProvider]:
+    def get_cached(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
         return self.cache.get((type, identifier), None)

     async def get_all(self) -> List[RoutableObjectWithProvider]:

@@ -153,9 +143,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
         async with self._locked_cache() as cache:
             return list(cache.values())

-    async def get(
-        self, type: str, identifier: str
-    ) -> Optional[RoutableObjectWithProvider]:
+    async def get(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
         await self._ensure_initialized()
         cache_key = (type, identifier)

@@ -197,9 +185,7 @@ async def create_dist_registry(
         dist_kvstore = await kvstore_impl(metadata_store)
     else:
         dist_kvstore = await kvstore_impl(
-            SqliteKVStoreConfig(
-                db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix()
-            )
+            SqliteKVStoreConfig(db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix())
         )
     dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
     await dist_registry.initialize()
@@ -161,9 +161,7 @@ async def test_duplicate_provider_registration(config):

     result = await cached_registry.get("vector_db", "test_vector_db_2")
     assert result is not None
-    assert (
-        result.embedding_model == original_vector_db.embedding_model
-    )  # Original values preserved
+    assert result.embedding_model == original_vector_db.embedding_model  # Original values preserved


 @pytest.mark.asyncio

@@ -193,14 +191,9 @@ async def test_get_all_objects(config):

     # Verify each vector_db was stored correctly
     for original_vector_db in test_vector_dbs:
-        matching_vector_dbs = [
-            v for v in all_results if v.identifier == original_vector_db.identifier
-        ]
+        matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier]
         assert len(matching_vector_dbs) == 1
         stored_vector_db = matching_vector_dbs[0]
         assert stored_vector_db.embedding_model == original_vector_db.embedding_model
         assert stored_vector_db.provider_id == original_vector_db.provider_id
-        assert (
-            stored_vector_db.embedding_dimension
-            == original_vector_db.embedding_dimension
-        )
+        assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
@@ -22,15 +22,11 @@ def main():
     )

     # Playground pages
-    chat_page = st.Page(
-        "page/playground/chat.py", title="Chat", icon="💬", default=True
-    )
+    chat_page = st.Page("page/playground/chat.py", title="Chat", icon="💬", default=True)
     rag_page = st.Page("page/playground/rag.py", title="RAG", icon="💬", default=False)

     # Distribution pages
-    resources_page = st.Page(
-        "page/distribution/resources.py", title="Resources", icon="🔍", default=False
-    )
+    resources_page = st.Page("page/distribution/resources.py", title="Resources", icon="🔍", default=False)
     provider_page = st.Page(
         "page/distribution/providers.py",
         title="API Providers",

@@ -23,15 +23,11 @@ class LlamaStackApi:
             },
         )

-    def run_scoring(
-        self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]
-    ):
+    def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]):
         """Run scoring on a single row"""
         if not scoring_params:
             scoring_params = {fn_id: None for fn_id in scoring_function_ids}
-        return self.client.scoring.score(
-            input_rows=[row], scoring_functions=scoring_params
-        )
+        return self.client.scoring.score(input_rows=[row], scoring_functions=scoring_params)


 llama_stack_api = LlamaStackApi()
@@ -11,9 +11,7 @@ from modules.api import llama_stack_api
 def datasets():
     st.header("Datasets")

-    datasets_info = {
-        d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list()
-    }
+    datasets_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list()}
     if len(datasets_info) > 0:
         selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys()))
         st.json(datasets_info[selected_dataset], expanded=True)

@@ -12,12 +12,8 @@ def eval_tasks():
     # Eval Tasks Section
     st.header("Eval Tasks")

-    eval_tasks_info = {
-        d.identifier: d.to_dict() for d in llama_stack_api.client.eval_tasks.list()
-    }
+    eval_tasks_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.eval_tasks.list()}

     if len(eval_tasks_info) > 0:
-        selected_eval_task = st.selectbox(
-            "Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect"
-        )
+        selected_eval_task = st.selectbox("Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect")
         st.json(eval_tasks_info[selected_eval_task], expanded=True)

@@ -11,9 +11,7 @@ from modules.api import llama_stack_api
 def models():
     # Models Section
     st.header("Models")
-    models_info = {
-        m.identifier: m.to_dict() for m in llama_stack_api.client.models.list()
-    }
+    models_info = {m.identifier: m.to_dict() for m in llama_stack_api.client.models.list()}

     selected_model = st.selectbox("Select a model", list(models_info.keys()))
     st.json(models_info[selected_model])

@@ -11,12 +11,7 @@ from modules.api import llama_stack_api
 def scoring_functions():
     st.header("Scoring Functions")

-    scoring_functions_info = {
-        s.identifier: s.to_dict()
-        for s in llama_stack_api.client.scoring_functions.list()
-    }
+    scoring_functions_info = {s.identifier: s.to_dict() for s in llama_stack_api.client.scoring_functions.list()}

-    selected_scoring_function = st.selectbox(
-        "Select a scoring function", list(scoring_functions_info.keys())
-    )
+    selected_scoring_function = st.selectbox("Select a scoring function", list(scoring_functions_info.keys()))
     st.json(scoring_functions_info[selected_scoring_function], expanded=True)

@@ -12,9 +12,7 @@ def shields():
     # Shields Section
     st.header("Shields")

-    shields_info = {
-        s.identifier: s.to_dict() for s in llama_stack_api.client.shields.list()
-    }
+    shields_info = {s.identifier: s.to_dict() for s in llama_stack_api.client.shields.list()}

     selected_shield = st.selectbox("Select a shield", list(shields_info.keys()))
     st.json(shields_info[selected_shield])

@@ -10,14 +10,10 @@ from modules.api import llama_stack_api

 def vector_dbs():
     st.header("Vector Databases")
-    vector_dbs_info = {
-        v.identifier: v.to_dict() for v in llama_stack_api.client.vector_dbs.list()
-    }
+    vector_dbs_info = {v.identifier: v.to_dict() for v in llama_stack_api.client.vector_dbs.list()}

     if len(vector_dbs_info) > 0:
-        selected_vector_db = st.selectbox(
-            "Select a vector database", list(vector_dbs_info.keys())
-        )
+        selected_vector_db = st.selectbox("Select a vector database", list(vector_dbs_info.keys()))
         st.json(vector_dbs_info[selected_vector_db])
     else:
         st.info("No vector databases found")
@@ -14,7 +14,6 @@ from modules.utils import process_dataset


 def application_evaluation_page():
-
     st.set_page_config(page_title="Evaluations (Scoring)", page_icon="🦙")
     st.title("📊 Evaluations (Scoring)")

@@ -83,9 +82,7 @@ def application_evaluation_page():
                 try:
                     new_params[param_name] = json.loads(value)
                 except json.JSONDecodeError:
-                    st.error(
-                        f"Invalid JSON for **{param_name}** in {scoring_fn_id}"
-                    )
+                    st.error(f"Invalid JSON for **{param_name}** in {scoring_fn_id}")

             st.json(new_params)
             scoring_params[scoring_fn_id] = new_params

@@ -128,9 +125,7 @@ def application_evaluation_page():
                 output_res[fn_id].append(score_res.results[fn_id].score_rows[0])

             # Display current row results using separate containers
-            progress_text_container.write(
-                f"Expand to see current processed result ({i + 1} / {len(rows)})"
-            )
+            progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
             results_container.json(
                 score_res.to_json(),
                 expanded=2,

@@ -195,7 +195,6 @@ def run_evaluation_3():

     # Add run button and handle evaluation
     if st.button("Run Evaluation"):
-
         progress_text = "Running evaluation..."
         progress_bar = st.progress(0, text=progress_text)
         rows = rows.rows

@@ -233,9 +232,7 @@ def run_evaluation_3():
                 output_res[scoring_fn] = []
             output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])

-            progress_text_container.write(
-                f"Expand to see current processed result ({i + 1} / {len(rows)})"
-            )
+            progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
             results_container.json(eval_res, expanded=2)

         progress_bar.progress(1.0, text="Evaluation complete!")

@@ -247,7 +244,6 @@ def run_evaluation_3():


 def native_evaluation_page():
-
     st.set_page_config(page_title="Evaluations (Generation + Scoring)", page_icon="🦙")
     st.title("📊 Evaluations (Generation + Scoring)")
@@ -11,9 +11,7 @@ from modules.api import llama_stack_api
 with st.sidebar:
     st.header("Configuration")
     available_models = llama_stack_api.client.models.list()
-    available_models = [
-        model.identifier for model in available_models if model.model_type == "llm"
-    ]
+    available_models = [model.identifier for model in available_models if model.model_type == "llm"]
     selected_model = st.selectbox(
         "Choose a model",
         available_models,

@@ -128,6 +126,4 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
         full_response = response
     message_placeholder.markdown(full_response.completion_message.content)

-    st.session_state.messages.append(
-        {"role": "assistant", "content": full_response}
-    )
+    st.session_state.messages.append({"role": "assistant", "content": full_response})

@@ -74,9 +74,7 @@ def rag_chat_page():
         )

         available_models = llama_stack_api.client.models.list()
-        available_models = [
-            model.identifier for model in available_models if model.model_type == "llm"
-        ]
+        available_models = [model.identifier for model in available_models if model.model_type == "llm"]
         selected_model = st.selectbox(
             "Choose a model",
             available_models,

@@ -137,9 +135,7 @@ def rag_chat_page():
             dict(
                 name="builtin::rag",
                 args={
-                    "vector_db_ids": [
-                        vector_db_id for vector_db_id in selected_vector_dbs
-                    ],
+                    "vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
                 },
             )
         ],

@@ -186,9 +182,7 @@ def rag_chat_page():
             message_placeholder.markdown(full_response + "▌")
         message_placeholder.markdown(full_response)

-        st.session_state.messages.append(
-            {"role": "assistant", "content": full_response}
-        )
+        st.session_state.messages.append({"role": "assistant", "content": full_response})


 rag_chat_page()
@@ -8,9 +8,7 @@ import os
 from pathlib import Path


-LLAMA_STACK_CONFIG_DIR = Path(
-    os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/"))
-)
+LLAMA_STACK_CONFIG_DIR = Path(os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/")))

 DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"

@@ -31,15 +31,11 @@ def is_list_of_primitives(field_type):


 def is_basemodel_without_fields(typ):
-    return (
-        inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0
-    )
+    return inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0


 def can_recurse(typ):
-    return (
-        inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
-    )
+    return inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0


 def get_literal_values(field):

@@ -72,7 +68,7 @@ def is_discriminated_union(typ) -> bool:
     if isinstance(typ, FieldInfo):
         return typ.discriminator
     else:
-        if not (get_origin(typ) is Annotated):
+        if get_origin(typ) is not Annotated:
             return False
         args = get_args(typ)
         return len(args) >= 2 and args[1].discriminator

@@ -116,9 +112,7 @@ def prompt_for_discriminated_union(
         chosen_type = type_map[discriminator_value]
         log.info(f"\nConfiguring {chosen_type.__name__}:")

-        if existing_value and (
-            getattr(existing_value, discriminator) != discriminator_value
-        ):
+        if existing_value and (getattr(existing_value, discriminator) != discriminator_value):
             existing_value = None

         sub_config = prompt_for_config(chosen_type, existing_value)

@@ -134,9 +128,7 @@ def prompt_for_discriminated_union(
 #
 # doesn't support List[nested_class] yet or Dicts of any kind. needs a bunch of
 # unit tests for coverage.
-def prompt_for_config(
-    config_type: type[BaseModel], existing_config: Optional[BaseModel] = None
-) -> BaseModel:
+def prompt_for_config(config_type: type[BaseModel], existing_config: Optional[BaseModel] = None) -> BaseModel:
     """
     Recursively prompt the user for configuration values based on a Pydantic BaseModel.

@@ -150,17 +142,11 @@ def prompt_for_config(

     for field_name, field in config_type.__fields__.items():
         field_type = field.annotation
-        existing_value = (
-            getattr(existing_config, field_name) if existing_config else None
-        )
+        existing_value = getattr(existing_config, field_name) if existing_config else None
         if existing_value:
             default_value = existing_value
         else:
-            default_value = (
-                field.default
-                if not isinstance(field.default, PydanticUndefinedType)
-                else None
-            )
+            default_value = field.default if not isinstance(field.default, PydanticUndefinedType) else None
         is_required = field.is_required

         # Skip fields with Literal type
@@ -183,15 +169,11 @@ def prompt_for_config(
                     config_data[field_name] = validated_value
                     break
                 except KeyError:
-                    log.error(
-                        f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}"
-                    )
+                    log.error(f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}")
                     continue

         if is_discriminated_union(field):
-            config_data[field_name] = prompt_for_discriminated_union(
-                field_name, field, existing_value
-            )
+            config_data[field_name] = prompt_for_discriminated_union(field_name, field, existing_value)
             continue

         if is_optional(field_type) and can_recurse(get_non_none_type(field_type)):

@@ -202,9 +184,7 @@ def prompt_for_config(
             nested_type = get_non_none_type(field_type)
             log.info(f"Entering sub-configuration for {field_name}:")
             config_data[field_name] = prompt_for_config(nested_type, existing_value)
-        elif is_optional(field_type) and is_discriminated_union(
-            get_non_none_type(field_type)
-        ):
+        elif is_optional(field_type) and is_discriminated_union(get_non_none_type(field_type)):
             prompt = f"Do you want to configure {field_name}? (y/n): "
             if input(prompt).lower() == "n":
                 config_data[field_name] = None

@@ -260,16 +240,12 @@ def prompt_for_config(
                 try:
                     value = json.loads(user_input)
                     if not isinstance(value, list):
-                        raise ValueError(
-                            "Input must be a JSON-encoded list"
-                        )
+                        raise ValueError("Input must be a JSON-encoded list")
                     element_type = get_args(field_type)[0]
                     value = [element_type(item) for item in value]

                 except json.JSONDecodeError:
-                    log.error(
-                        'Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]'
-                    )
+                    log.error('Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]')
                     continue
                 except ValueError as e:
                     log.error(f"{str(e)}")

@@ -279,20 +255,14 @@ def prompt_for_config(
                 try:
                     value = json.loads(user_input)
                     if not isinstance(value, dict):
-                        raise ValueError(
-                            "Input must be a JSON-encoded dictionary"
-                        )
+                        raise ValueError("Input must be a JSON-encoded dictionary")

                 except json.JSONDecodeError:
-                    log.error(
-                        "Invalid JSON. Please enter a valid JSON-encoded dict."
-                    )
+                    log.error("Invalid JSON. Please enter a valid JSON-encoded dict.")
                     continue

             # Convert the input to the correct type
-            elif inspect.isclass(field_type) and issubclass(
-                field_type, BaseModel
-            ):
+            elif inspect.isclass(field_type) and issubclass(field_type, BaseModel):
                 # For nested BaseModels, we assume a dictionary-like string input
                 import ast

@@ -301,16 +271,12 @@ def prompt_for_config(
                     value = field_type(user_input)

                 except ValueError:
-                    log.error(
-                        f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}"
-                    )
+                    log.error(f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}")
                     continue

             try:
                 # Validate the field using our manual validation function
-                validated_value = manually_validate_field(
-                    config_type, field_name, value
-                )
+                validated_value = manually_validate_field(config_type, field_name, value)
                 config_data[field_name] = validated_value
                 break
             except ValueError as e:
@@ -11,9 +11,7 @@ from llama_stack.distribution.datatypes import Api, ProviderSpec
 from .config import MetaReferenceAgentsImplConfig


-async def get_provider_impl(
-    config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]
-):
+async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]):
     from .agents import MetaReferenceAgentsImpl

     impl = MetaReferenceAgentsImpl(

@@ -74,9 +74,7 @@ log = logging.getLogger(__name__)


 def make_random_string(length: int = 8):
-    return "".join(
-        secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
-    )
+    return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))


 TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
||||||
async def create_session(self, name: str) -> str:
|
async def create_session(self, name: str) -> str:
|
||||||
return await self.storage.create_session(name)
|
return await self.storage.create_session(name)
|
||||||
|
|
||||||
async def create_and_execute_turn(
|
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||||
self, request: AgentTurnCreateRequest
|
|
||||||
) -> AsyncGenerator:
|
|
||||||
with tracing.span("create_and_execute_turn") as span:
|
with tracing.span("create_and_execute_turn") as span:
|
||||||
span.set_attribute("session_id", request.session_id)
|
span.set_attribute("session_id", request.session_id)
|
||||||
span.set_attribute("agent_id", self.agent_id)
|
span.set_attribute("agent_id", self.agent_id)
|
||||||
|
@ -206,14 +202,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
output_message = chunk
|
output_message = chunk
|
||||||
continue
|
continue
|
||||||
|
|
||||||
assert isinstance(
|
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
|
||||||
chunk, AgentTurnResponseStreamChunk
|
|
||||||
), f"Unexpected type {type(chunk)}"
|
|
||||||
event = chunk.event
|
event = chunk.event
|
||||||
if (
|
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
|
||||||
event.payload.event_type
|
|
||||||
== AgentTurnResponseEventType.step_complete.value
|
|
||||||
):
|
|
||||||
steps.append(event.payload.step_details)
|
steps.append(event.payload.step_details)
|
||||||
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
@ -388,9 +379,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
|
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
|
||||||
if documents:
|
if documents:
|
||||||
await self.handle_documents(
|
await self.handle_documents(session_id, documents, input_messages, tool_defs)
|
||||||
session_id, documents, input_messages, tool_defs
|
|
||||||
)
|
|
||||||
|
|
||||||
if RAG_TOOL_GROUP in toolgroups and len(input_messages) > 0:
|
if RAG_TOOL_GROUP in toolgroups and len(input_messages) > 0:
|
||||||
with tracing.span(MEMORY_QUERY_TOOL) as span:
|
with tracing.span(MEMORY_QUERY_TOOL) as span:
|
||||||
|
@ -408,9 +397,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
vector_db_ids = args.get("vector_db_ids", [])
|
vector_db_ids = args.get("vector_db_ids", [])
|
||||||
query_config = args.get("query_config")
|
query_config = args.get("query_config")
|
||||||
if query_config:
|
if query_config:
|
||||||
query_config = TypeAdapter(RAGQueryConfig).validate_python(
|
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
|
||||||
query_config
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# handle someone passing an empty dict
|
# handle someone passing an empty dict
|
||||||
query_config = RAGQueryConfig()
|
query_config = RAGQueryConfig()
|
||||||
|
@ -438,9 +425,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
result = await self.tool_runtime_api.rag_tool.query(
|
result = await self.tool_runtime_api.rag_tool.query(
|
||||||
content=concat_interleaved_content(
|
content=concat_interleaved_content([msg.content for msg in input_messages]),
|
||||||
[msg.content for msg in input_messages]
|
|
||||||
),
|
|
||||||
vector_db_ids=vector_db_ids,
|
vector_db_ids=vector_db_ids,
|
||||||
query_config=query_config,
|
query_config=query_config,
|
||||||
)
|
)
|
||||||
|
@ -472,9 +457,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
span.set_attribute(
|
span.set_attribute("input", [m.model_dump_json() for m in input_messages])
|
||||||
"input", [m.model_dump_json() for m in input_messages]
|
|
||||||
)
|
|
||||||
span.set_attribute("output", retrieved_context)
|
span.set_attribute("output", retrieved_context)
|
||||||
span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
|
span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
|
||||||
|
|
||||||
|
@ -511,9 +494,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
self.agent_config.model,
|
self.agent_config.model,
|
||||||
input_messages,
|
input_messages,
|
||||||
tools=[
|
tools=[
|
||||||
tool
|
tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
|
||||||
for tool in tool_defs.values()
|
|
||||||
if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
|
|
||||||
],
|
],
|
||||||
tool_prompt_format=self.agent_config.tool_prompt_format,
|
tool_prompt_format=self.agent_config.tool_prompt_format,
|
||||||
response_format=self.agent_config.response_format,
|
response_format=self.agent_config.response_format,
|
||||||
|
@@ -560,12 +541,8 @@ class ChatAgent(ShieldRunnerMixin):
                     if event.stop_reason is not None:
                         stop_reason = event.stop_reason
                 span.set_attribute("stop_reason", stop_reason)
-                span.set_attribute(
-                    "input", [m.model_dump_json() for m in input_messages]
-                )
-                span.set_attribute(
-                    "output", f"content: {content} tool_calls: {tool_calls}"
-                )
+                span.set_attribute("input", [m.model_dump_json() for m in input_messages])
+                span.set_attribute("output", f"content: {content} tool_calls: {tool_calls}")

             stop_reason = stop_reason or StopReason.out_of_tokens

@@ -667,9 +644,7 @@ class ChatAgent(ShieldRunnerMixin):
                     toolgroup_args,
                     tool_to_group,
                 )
-                assert (
-                    len(result_messages) == 1
-                ), "Currently not supporting multiple messages"
+                assert len(result_messages) == 1, "Currently not supporting multiple messages"
                 result_message = result_messages[0]
                 span.set_attribute("output", result_message.model_dump_json())

@@ -697,9 +672,7 @@ class ChatAgent(ShieldRunnerMixin):
             # TODO: add tool-input touchpoint and a "start" event for this step also
             # but that needs a lot more refactoring of Tool code potentially

-            if out_attachment := _interpret_content_as_attachment(
-                result_message.content
-            ):
+            if out_attachment := _interpret_content_as_attachment(result_message.content):
                 # NOTE: when we push this message back to the model, the model may ignore the
                 # attached file path etc. since the model is trained to only provide a user message
                 # with the summary. We keep all generated attachments and then attach them to final message

@@ -714,22 +687,14 @@ class ChatAgent(ShieldRunnerMixin):
     ) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]:
         # Determine which tools to include
         agent_config_toolgroups = set(
-            (
-                toolgroup.name
-                if isinstance(toolgroup, AgentToolGroupWithArgs)
-                else toolgroup
-            )
+            (toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
             for toolgroup in self.agent_config.toolgroups
         )
         toolgroups_for_turn_set = (
             agent_config_toolgroups
             if toolgroups_for_turn is None
             else {
-                (
-                    toolgroup.name
-                    if isinstance(toolgroup, AgentToolGroupWithArgs)
-                    else toolgroup
-                )
+                (toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
                 for toolgroup in toolgroups_for_turn
             }
         )
@@ -759,10 +724,7 @@ class ChatAgent(ShieldRunnerMixin):
                 continue
             tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
             for tool_def in tools.data:
-                if (
-                    toolgroup_name.startswith("builtin")
-                    and toolgroup_name != RAG_TOOL_GROUP
-                ):
+                if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
                     tool_name = tool_def.identifier
                     built_in_type = BuiltinTool.brave_search
                     if tool_name == "web_search":

@@ -773,9 +735,7 @@ class ChatAgent(ShieldRunnerMixin):
                     if tool_def_map.get(built_in_type, None):
                         raise ValueError(f"Tool {built_in_type} already exists")

-                    tool_def_map[built_in_type] = ToolDefinition(
-                        tool_name=built_in_type
-                    )
+                    tool_def_map[built_in_type] = ToolDefinition(tool_name=built_in_type)
                     tool_to_group[built_in_type] = tool_def.toolgroup_id
                     continue

@@ -821,9 +781,7 @@ class ChatAgent(ShieldRunnerMixin):
         # Save the contents to a tempdir and use its path as a URL if code interpreter is present
         if code_interpreter_tool:
             for c in content_items:
-                temp_file_path = os.path.join(
-                    self.tempdir, f"{make_random_string()}.txt"
-                )
+                temp_file_path = os.path.join(self.tempdir, f"{make_random_string()}.txt")
                 with open(temp_file_path, "w") as temp_file:
                     temp_file.write(c.content)
                 url_items.append(URL(uri=f"file://{temp_file_path}"))

@@ -849,8 +807,7 @@ class ChatAgent(ShieldRunnerMixin):
             # we try to load the data from the URLs and content items as a message to inference
            # and add it to the last message's context
             input_messages[-1].context = "\n".join(
-                [doc.content for doc in content_items]
-                + await load_data_from_urls(url_items)
+                [doc.content for doc in content_items] + await load_data_from_urls(url_items)
             )

     async def _ensure_vector_db(self, session_id: str) -> str:

@@ -874,9 +831,7 @@ class ChatAgent(ShieldRunnerMixin):

         return vector_db_id

-    async def add_to_session_vector_db(
-        self, session_id: str, data: List[Document]
-    ) -> None:
+    async def add_to_session_vector_db(self, session_id: str, data: List[Document]) -> None:
         vector_db_id = await self._ensure_vector_db(session_id)
         documents = [
             RAGDocument(

@@ -931,11 +886,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
         else:
             raise ValueError(f"Unsupported URL {url}")

-        content.append(
-            TextContentItem(
-                text=f'# There is a file accessible to you at "{filepath}"\n'
-            )
-        )
+        content.append(TextContentItem(text=f'# There is a file accessible to you at "{filepath}"\n'))

     return ToolResponseMessage(
         call_id="",
@@ -94,16 +94,12 @@ class MetaReferenceAgentsImpl(Agents):
 try:
 agent_config = json.loads(agent_config)
 except json.JSONDecodeError as e:
-raise ValueError(
-f"Could not JSON decode agent config for {agent_id}"
-) from e
+raise ValueError(f"Could not JSON decode agent config for {agent_id}") from e

 try:
 agent_config = AgentConfig(**agent_config)
 except Exception as e:
-raise ValueError(
-f"Could not validate(?) agent config for {agent_id}"
-) from e
+raise ValueError(f"Could not validate(?) agent config for {agent_id}") from e

 return ChatAgent(
 agent_id=agent_id,

@@ -115,9 +111,7 @@ class MetaReferenceAgentsImpl(Agents):
 tool_runtime_api=self.tool_runtime_api,
 tool_groups_api=self.tool_groups_api,
 persistence_store=(
-self.persistence_store
-if agent_config.enable_session_persistence
-else self.in_memory_store
+self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store
 ),
 )

@@ -168,22 +162,14 @@ class MetaReferenceAgentsImpl(Agents):
 async for event in agent.create_and_execute_turn(request):
 yield event

-async def get_agents_turn(
-self, agent_id: str, session_id: str, turn_id: str
-) -> Turn:
-turn = await self.persistence_store.get(
-f"session:{agent_id}:{session_id}:{turn_id}"
-)
+async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
+turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
 turn = json.loads(turn)
 turn = Turn(**turn)
 return turn

-async def get_agents_step(
-self, agent_id: str, session_id: str, turn_id: str, step_id: str
-) -> AgentStepResponse:
-turn = await self.persistence_store.get(
-f"session:{agent_id}:{session_id}:{turn_id}"
-)
+async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
+turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
 turn = json.loads(turn)
 turn = Turn(**turn)
 steps = turn.steps

@@ -203,9 +189,7 @@ class MetaReferenceAgentsImpl(Agents):
 turns = []
 if turn_ids:
 for turn_id in turn_ids:
-turn = await self.persistence_store.get(
-f"session:{agent_id}:{session_id}:{turn_id}"
-)
+turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
 turn = json.loads(turn)
 turn = Turn(**turn)
 turns.append(turn)
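Aside, not part of the diff: nearly every hunk in this commit is the same mechanical change. A call, comprehension, or argument list that the previous formatter wrapped across several lines now fits within the new 120-column limit, so it is joined onto one line. A minimal sketch of that before/after shape, using hypothetical stand-in names rather than the repository's actual classes:

```python
from typing import Optional, Protocol


class KVStore(Protocol):  # hypothetical stand-in for the persistence store used above
    async def get(self, key: str) -> Optional[str]: ...


# Old shape: the argument was wrapped because the call exceeded the previous line length.
async def get_turn_old(store: KVStore, agent_id: str, session_id: str, turn_id: str) -> Optional[str]:
    return await store.get(
        f"session:{agent_id}:{session_id}:{turn_id}"
    )


# New shape: the same call fits within line-length = 120, so it stays on one line.
async def get_turn_new(store: KVStore, agent_id: str, session_id: str, turn_id: str) -> Optional[str]:
    return await store.get(f"session:{agent_id}:{session_id}:{turn_id}")
```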
@@ -33,9 +33,7 @@ class ShieldRunnerMixin:
 self.input_shields = input_shields
 self.output_shields = output_shields

-async def run_multiple_shields(
-self, messages: List[Message], identifiers: List[str]
-) -> None:
+async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
 responses = await asyncio.gather(
 *[
 self.safety_api.run_shield(
@@ -64,9 +64,7 @@ class MockInferenceAPI:
 tool_prompt_format: Optional[ToolPromptFormat] = None,
 stream: Optional[bool] = False,
 logprobs: Optional[LogProbConfig] = None,
-) -> Union[
-ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
-]:
+) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
 async def stream_response():
 yield ChatCompletionResponseStreamChunk(
 event=ChatCompletionResponseEvent(

@@ -104,9 +102,7 @@ class MockInferenceAPI:

 class MockSafetyAPI:
-async def run_shield(
-self, shield_id: str, messages: List[Message]
-) -> RunShieldResponse:
+async def run_shield(self, shield_id: str, messages: List[Message]) -> RunShieldResponse:
 return RunShieldResponse(violation=None)

@@ -129,9 +125,7 @@ class MockVectorIOAPI:

 class MockToolGroupsAPI:
-async def register_tool_group(
-self, toolgroup_id: str, provider_id: str, mcp_endpoint=None, args=None
-) -> None:
+async def register_tool_group(self, toolgroup_id: str, provider_id: str, mcp_endpoint=None, args=None) -> None:
 pass

 async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
@@ -341,26 +335,21 @@ async def test_chat_agent_complex_turn(get_chat_agent):
 assert len(responses) > 0

 step_types = [
-response.event.payload.step_type
-for response in responses
-if hasattr(response.event.payload, "step_type")
+response.event.payload.step_type for response in responses if hasattr(response.event.payload, "step_type")
 ]

 assert StepType.shield_call in step_types, "Shield call step is missing"
 assert StepType.inference in step_types, "Inference step is missing"

 event_types = [
-response.event.payload.event_type
-for response in responses
-if hasattr(response.event.payload, "event_type")
+response.event.payload.event_type for response in responses if hasattr(response.event.payload, "event_type")
 ]
 assert "turn_start" in event_types, "Start event is missing"
 assert "turn_complete" in event_types, "Complete event is missing"

-assert any(
-isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
-for response in responses
-), "Turn complete event is missing"
+assert any(isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload) for response in responses), (
+"Turn complete event is missing"
+)
 turn_complete_payload = next(
 response.event.payload
 for response in responses

@@ -380,9 +369,7 @@ async def test_chat_agent_complex_turn(get_chat_agent):
 ([MEMORY_TOOLGROUP, CODE_INTERPRETER_TOOLGROUP], True, True),  # all tools
 ],
 )
-async def test_chat_agent_tools(
-get_agents_impl, toolgroups, expected_memory, expected_code_interpreter
-):
+async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, expected_code_interpreter):
 impl = await get_agents_impl
 agent_config = AgentConfig(
 model="test_model",
@@ -172,9 +172,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
 new_rows_df = pandas.DataFrame(rows)
 new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
-dataset_impl.df = pandas.concat(
-[dataset_impl.df, new_rows_df], ignore_index=True
-)
+dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)

 url = str(dataset_info.dataset_def.url)
 parsed_url = urlparse(url)

@@ -189,12 +187,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
 raise ValueError("Data URL must be a base64-encoded CSV")

 csv_buffer = dataset_impl.df.to_csv(index=False)
-base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode(
-"utf-8"
-)
-dataset_info.dataset_def.url = URL(
-uri=f"data:text/csv;base64,{base64_content}"
-)
+base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode("utf-8")
+dataset_info.dataset_def.url = URL(uri=f"data:text/csv;base64,{base64_content}")
 else:
 raise ValueError(
 f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."
@@ -91,14 +91,10 @@ class MetaReferenceEvalImpl(
 candidate = task_config.eval_candidate
 scoring_functions = task_def.scoring_functions
 dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-validate_dataset_schema(
-dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)
-)
+validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
 all_rows = await self.datasetio_api.get_rows_paginated(
 dataset_id=dataset_id,
-rows_in_page=(
--1 if task_config.num_examples is None else task_config.num_examples
-),
+rows_in_page=(-1 if task_config.num_examples is None else task_config.num_examples),
 )
 res = await self.evaluate_rows(
 task_id=task_id,

@@ -127,9 +123,7 @@ class MetaReferenceEvalImpl(
 input_messages = [UserMessage(**x) for x in input_messages]

 # NOTE: only single-turn agent generation is supported. Create a new session for each input row
-session_create_response = await self.agents_api.create_agent_session(
-agent_id, f"session-{i}"
-)
+session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
 session_id = session_create_response.session_id

 turn_request = dict(

@@ -138,12 +132,7 @@ class MetaReferenceEvalImpl(
 messages=input_messages,
 stream=True,
 )
-turn_response = [
-chunk
-async for chunk in await self.agents_api.create_agent_turn(
-**turn_request
-)
-]
+turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)]
 final_event = turn_response[-1].event.payload

 # check if there's a memory retrieval step and extract the context

@@ -152,14 +141,10 @@ class MetaReferenceEvalImpl(
 if step.step_type == StepType.tool_execution.value:
 for tool_response in step.tool_responses:
 if tool_response.tool_name == MEMORY_QUERY_TOOL:
-memory_rag_context = " ".join(
-x.text for x in tool_response.content
-)
+memory_rag_context = " ".join(x.text for x in tool_response.content)

 agent_generation = {}
-agent_generation[ColumnName.generated_answer.value] = (
-final_event.turn.output_message.content
-)
+agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content
 if memory_rag_context:
 agent_generation[ColumnName.context.value] = memory_rag_context

@@ -171,9 +156,7 @@ class MetaReferenceEvalImpl(
 self, input_rows: List[Dict[str, Any]], task_config: EvalTaskConfig
 ) -> List[Dict[str, Any]]:
 candidate = task_config.eval_candidate
-assert (
-candidate.sampling_params.max_tokens is not None
-), "SamplingParams.max_tokens must be provided"
+assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"

 generations = []
 for x in tqdm(input_rows):

@@ -184,15 +167,9 @@ class MetaReferenceEvalImpl(
 content=input_content,
 sampling_params=candidate.sampling_params,
 )
-generations.append(
-{
-ColumnName.generated_answer.value: response.completion_message.content
-}
-)
+generations.append({ColumnName.generated_answer.value: response.completion_message.content})
 elif ColumnName.chat_completion_input.value in x:
-chat_completion_input_str = str(
-x[ColumnName.chat_completion_input.value]
-)
+chat_completion_input_str = str(x[ColumnName.chat_completion_input.value])
 input_messages = eval(chat_completion_input_str)
 input_messages = [UserMessage(**x) for x in input_messages]
 messages = []

@@ -204,11 +181,7 @@ class MetaReferenceEvalImpl(
 messages=messages,
 sampling_params=candidate.sampling_params,
 )
-generations.append(
-{
-ColumnName.generated_answer.value: response.completion_message.content
-}
-)
+generations.append({ColumnName.generated_answer.value: response.completion_message.content})
 else:
 raise ValueError("Invalid input row")

@@ -230,10 +203,7 @@ class MetaReferenceEvalImpl(
 raise ValueError(f"Invalid candidate type: {candidate.type}")

 # scoring with generated_answer
-score_input_rows = [
-input_r | generated_r
-for input_r, generated_r in zip(input_rows, generations)
-]
+score_input_rows = [input_r | generated_r for input_r, generated_r in zip(input_rows, generations)]

 if task_config.type == "app" and task_config.scoring_params is not None:
 scoring_functions_dict = {

@@ -241,9 +211,7 @@ class MetaReferenceEvalImpl(
 for scoring_fn_id in scoring_functions
 }
 else:
-scoring_functions_dict = {
-scoring_fn_id: None for scoring_fn_id in scoring_functions
-}
+scoring_functions_dict = {scoring_fn_id: None for scoring_fn_id in scoring_functions}

 score_response = await self.scoring_api.score(
 input_rows=score_input_rows, scoring_functions=scoring_functions_dict
@@ -40,9 +40,7 @@ class MetaReferenceInferenceConfig(BaseModel):
 repos = [m.huggingface_repo for m in permitted_models]
 if model not in (descriptors + repos):
 model_list = "\n\t".join(repos)
-raise ValueError(
-f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
-)
+raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
 return model

 @classmethod
@@ -83,9 +83,7 @@ class TokenResult(BaseModel):
 class Llama:
 @staticmethod
 def build(
-config: Union[
-MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
-],
+config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
 model_id: str,
 llama_model: Model,
 ):

@@ -150,9 +148,9 @@ class Llama:
 checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
 assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
-assert model_parallel_size == len(
-checkpoints
-), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
+assert model_parallel_size == len(checkpoints), (
+f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
+)
 ckpt_path = checkpoints[get_model_parallel_rank()]
 state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
 with open(Path(ckpt_dir) / "params.json", "r") as f:

@@ -168,9 +166,9 @@ class Llama:
 )

 tokenizer = Tokenizer.get_instance()
-assert (
-model_args.vocab_size == tokenizer.n_words
-), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
+assert model_args.vocab_size == tokenizer.n_words, (
+f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
+)

 if isinstance(config, MetaReferenceQuantizedInferenceConfig):
 if isinstance(config.quantization, Fp8QuantizationConfig):

@@ -193,10 +191,7 @@ class Llama:
 model = convert_to_int4_quantized_model(model, model_args, config)
 model.load_state_dict(state_dict, strict=True)

-if (
-model_args.quantization_args is not None
-and model_args.quantization_args.spinquant
-):
+if model_args.quantization_args is not None and model_args.quantization_args.spinquant:
 # Add a wrapper for adding hadamard transform for spinquant.
 # This needs to be done after loading the state dict otherwise an error will be raised while
 # loading the state dict.

@@ -206,9 +201,7 @@ class Llama:
 add_hadamard_transform_for_spinquant(model)
 else:
-raise NotImplementedError(
-"Currently int4 and fp8 are the only supported quantization methods."
-)
+raise NotImplementedError("Currently int4 and fp8 are the only supported quantization methods.")
 else:
 if device == "cuda":
 if torch.cuda.is_bf16_supported():

@@ -262,10 +255,7 @@ class Llama:
 params = self.model.params

 if print_input_tokens:
-input_tokens = [
-self.formatter.vision_token if t == 128256 else t
-for t in model_input.tokens
-]
+input_tokens = [self.formatter.vision_token if t == 128256 else t for t in model_input.tokens]
 log.info("Input to model -> " + self.tokenizer.decode(input_tokens))
 prompt_tokens = [model_input.tokens]

@@ -287,13 +277,11 @@ class Llama:
 mask = model_input.vision.mask if model_input.vision is not None else []

 # the method works for bsz > 1 so add a batch dimension
-xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = (
-self.model.compute_vision_tokens_masks(
-batch_images=[images],
-batch_masks=[mask],
-total_len=total_len,
-)
-)
+xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
+batch_images=[images],
+batch_masks=[mask],
+total_len=total_len,
+)

 pad_id = self.tokenizer.pad_id
 tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)

@@ -340,9 +328,7 @@ class Llama:
 next_token = next_token.reshape(-1)
 # only replace token if prompt has already been generated
-next_token = torch.where(
-input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
-)
+next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
 tokens[:, cur_pos] = next_token

 target = tokens[:, prev_pos + 1 : cur_pos + 1]

@@ -365,17 +351,11 @@ class Llama:
 reduction="none",
 ignore_index=pad_id,
 )
-eos_reached |= (~input_text_mask[:, cur_pos]) & (
-torch.isin(next_token, stop_tokens)
-)
+eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
 yield TokenResult(
 token=next_token[0].item(),
 text=self.tokenizer.decode(next_token.tolist()),
-logprobs=(
-token_logprobs[:, cur_pos : cur_pos + 1][0].tolist()
-if logprobs
-else None
-),
+logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
 )

 prev_pos = cur_pos

@@ -388,11 +368,7 @@ class Llama:
 ) -> Generator:
 sampling_params = request.sampling_params
 max_gen_len = sampling_params.max_tokens
-if (
-max_gen_len is None
-or max_gen_len == 0
-or max_gen_len >= self.model.params.max_seq_len
-):
+if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
 max_gen_len = self.model.params.max_seq_len - 1

 model_input = self.formatter.encode_content(request.content)

@@ -417,11 +393,7 @@ class Llama:
 ) -> Generator:
 sampling_params = request.sampling_params
 max_gen_len = sampling_params.max_tokens
-if (
-max_gen_len is None
-or max_gen_len == 0
-or max_gen_len >= self.model.params.max_seq_len
-):
+if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
 max_gen_len = self.model.params.max_seq_len - 1

 temperature, top_p = _infer_sampling_params(sampling_params)

@@ -473,9 +445,7 @@ class LogitsProcessor:
 self.token_enforcer = token_enforcer
 self.mask: Optional[torch.Tensor] = None

-def process_logits(
-self, tokens: torch.Tensor, scores: torch.Tensor
-) -> torch.Tensor:
+def process_logits(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
 token_sequence = tokens[0, :].tolist()
 allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)

@@ -510,9 +480,7 @@ def get_logits_processor(
 return LogitsProcessor(token_enforcer)

-def _build_regular_tokens_list(
-tokenizer: Tokenizer, vocab_size: int
-) -> List[Tuple[int, str, bool]]:
+def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> List[Tuple[int, str, bool]]:
 token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
 regular_tokens = []
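Aside, not part of the diff: the assert statements in this file change shape rather than simply joining onto one line. When the failure message is long, the condition moves back onto the assert line and the message is parenthesized instead. A minimal sketch of that before/after, with hypothetical parameter names:

```python
# Old shape: the condition was parenthesized and wrapped, with the message after the closing paren.
def check_vocab_old(model_vocab: int, tokenizer_vocab: int) -> None:
    assert (
        model_vocab == tokenizer_vocab
    ), f"model_args vocab = {model_vocab} but tokenizer vocab = {tokenizer_vocab}"


# New shape: the condition stays on the assert line and the long message is parenthesized.
def check_vocab_new(model_vocab: int, tokenizer_vocab: int) -> None:
    assert model_vocab == tokenizer_vocab, (
        f"model_args vocab = {model_vocab} but tokenizer vocab = {tokenizer_vocab}"
    )
```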
@@ -80,9 +80,7 @@ class MetaReferenceInferenceImpl(
 async def load_model(self, model_id, llama_model) -> None:
 log.info(f"Loading model `{model_id}`")
 if self.config.create_distributed_process_group:
-self.generator = LlamaModelParallelGenerator(
-self.config, model_id, llama_model
-)
+self.generator = LlamaModelParallelGenerator(self.config, model_id, llama_model)
 self.generator.start()
 else:
 self.generator = Llama.build(self.config, model_id, llama_model)

@@ -100,9 +98,7 @@ class MetaReferenceInferenceImpl(
 "No avaible model yet, please register your requested model or add your model in the resouces first"
 )
 elif request.model != self.model_id:
-raise RuntimeError(
-f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}"
-)
+raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")

 async def unregister_model(self, model_id: str) -> None:
 pass

@@ -184,13 +180,7 @@ class MetaReferenceInferenceImpl(
 if request.logprobs:
 assert len(token_result.logprobs) == 1

-logprobs = [
-TokenLogProbs(
-logprobs_by_token={
-token_result.text: token_result.logprobs[0]
-}
-)
-]
+logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]

 yield CompletionResponseStreamChunk(
 delta=text,

@@ -212,9 +202,7 @@ class MetaReferenceInferenceImpl(
 for x in impl():
 yield x

-async def _nonstream_completion(
-self, request: CompletionRequest
-) -> CompletionResponse:
+async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
 def impl():
 tokens = []
 logprobs = []

@@ -231,13 +219,7 @@ class MetaReferenceInferenceImpl(
 if request.logprobs:
 assert len(token_result.logprobs) == 1

-logprobs.append(
-TokenLogProbs(
-logprobs_by_token={
-token_result.text: token_result.logprobs[0]
-}
-)
-)
+logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))

 if stop_reason is None:
 stop_reason = StopReason.out_of_tokens

@@ -289,9 +271,7 @@ class MetaReferenceInferenceImpl(
 self.check_model(request)

 # augment and rewrite messages depending on the model
-request.messages = chat_completion_request_to_messages(
-request, self.llama_model.core_model_id.value
-)
+request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
 # download media and convert to raw content so we can send it to the model
 request = await convert_request_to_raw(request)

@@ -304,9 +284,7 @@ class MetaReferenceInferenceImpl(
 else:
 return await self._nonstream_chat_completion(request)

-async def _nonstream_chat_completion(
-self, request: ChatCompletionRequest
-) -> ChatCompletionResponse:
+async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
 def impl():
 tokens = []
 logprobs = []

@@ -323,20 +301,12 @@ class MetaReferenceInferenceImpl(
 if request.logprobs:
 assert len(token_result.logprobs) == 1

-logprobs.append(
-TokenLogProbs(
-logprobs_by_token={
-token_result.text: token_result.logprobs[0]
-}
-)
-)
+logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))

 if stop_reason is None:
 stop_reason = StopReason.out_of_tokens

-raw_message = self.generator.formatter.decode_assistant_message(
-tokens, stop_reason
-)
+raw_message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
 return ChatCompletionResponse(
 completion_message=CompletionMessage(
 content=raw_message.content,

@@ -352,9 +322,7 @@ class MetaReferenceInferenceImpl(
 else:
 return impl()

-async def _stream_chat_completion(
-self, request: ChatCompletionRequest
-) -> AsyncGenerator:
+async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
 def impl():
 yield ChatCompletionResponseStreamChunk(
 event=ChatCompletionResponseEvent(

@@ -405,13 +373,7 @@ class MetaReferenceInferenceImpl(
 if request.logprobs:
 assert len(token_result.logprobs) == 1

-logprobs.append(
-TokenLogProbs(
-logprobs_by_token={
-token_result.text: token_result.logprobs[0]
-}
-)
-)
+logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
 yield ChatCompletionResponseStreamChunk(
 event=ChatCompletionResponseEvent(
 event_type=ChatCompletionResponseEventType.progress,

@@ -424,9 +386,7 @@ class MetaReferenceInferenceImpl(
 if stop_reason is None:
 stop_reason = StopReason.out_of_tokens

-message = self.generator.formatter.decode_assistant_message(
-tokens, stop_reason
-)
+message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)

 parsed_tool_calls = len(message.tool_calls) > 0
 if ipython and not parsed_tool_calls:
@@ -91,9 +91,7 @@ class LlamaModelParallelGenerator:
 self.group = ModelParallelProcessGroup(
 model_parallel_size,
-init_model_cb=partial(
-init_model_cb, self.config, self.model_id, self.llama_model
-),
+init_model_cb=partial(init_model_cb, self.config, self.model_id, self.llama_model),
 )
 self.group.start()
 return self
@@ -55,47 +55,33 @@ class ProcessingMessageName(str, Enum):

 class ReadyRequest(BaseModel):
-type: Literal[ProcessingMessageName.ready_request] = (
-ProcessingMessageName.ready_request
-)
+type: Literal[ProcessingMessageName.ready_request] = ProcessingMessageName.ready_request

 class ReadyResponse(BaseModel):
-type: Literal[ProcessingMessageName.ready_response] = (
-ProcessingMessageName.ready_response
-)
+type: Literal[ProcessingMessageName.ready_response] = ProcessingMessageName.ready_response

 class EndSentinel(BaseModel):
-type: Literal[ProcessingMessageName.end_sentinel] = (
-ProcessingMessageName.end_sentinel
-)
+type: Literal[ProcessingMessageName.end_sentinel] = ProcessingMessageName.end_sentinel

 class CancelSentinel(BaseModel):
-type: Literal[ProcessingMessageName.cancel_sentinel] = (
-ProcessingMessageName.cancel_sentinel
-)
+type: Literal[ProcessingMessageName.cancel_sentinel] = ProcessingMessageName.cancel_sentinel

 class TaskRequest(BaseModel):
-type: Literal[ProcessingMessageName.task_request] = (
-ProcessingMessageName.task_request
-)
+type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
 task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]

 class TaskResponse(BaseModel):
-type: Literal[ProcessingMessageName.task_response] = (
-ProcessingMessageName.task_response
-)
+type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
 result: TokenResult

 class ExceptionResponse(BaseModel):
-type: Literal[ProcessingMessageName.exception_response] = (
-ProcessingMessageName.exception_response
-)
+type: Literal[ProcessingMessageName.exception_response] = ProcessingMessageName.exception_response
 error: str

@@ -189,9 +175,7 @@ def retrieve_requests(reply_socket_url: str):
 group=get_model_parallel_group(),
 )
 if isinstance(updates[0], CancelSentinel):
-log.info(
-"quitting generation loop because request was cancelled"
-)
+log.info("quitting generation loop because request was cancelled")
 break

 if mp_rank_0():

@@ -350,9 +334,7 @@ class ModelParallelProcessGroup:
 def run_inference(
 self,
-req: Union[
-CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent
-],
+req: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent],
 ) -> Generator:
 assert not self.running, "inference already running"
@@ -19,9 +19,7 @@ try:
 log.info("Using efficient FP8 operators in FBGEMM.")
 except ImportError:
-log.error(
-"No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt."
-)
+log.error("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
 raise

 import torch

@@ -60,14 +58,8 @@ def ffn_swiglu(
 num_tokens: Optional[Tensor] = None,
 is_memory_bounded: bool = False,
 ) -> Tensor:
-if (
-isinstance(w1, Fp8ScaledWeights)
-and isinstance(w3, Fp8ScaledWeights)
-and isinstance(w2, Fp8ScaledWeights)
-):
-return ffn_swiglu_fp8_dynamic(
-x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded
-)
+if isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights):
+return ffn_swiglu_fp8_dynamic(x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded)

 (B, T, D) = x.shape  # noqa: N806
 (HD_L, D_) = w1.shape  # noqa: N806

@@ -146,12 +138,8 @@ def fc_fp8_dynamic(
 Single w8a8 fc layer with dynamic row-wise scaling.
 """
 if isinstance(w, Fp8RowwiseWeights):
-xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
-x, num_tokens, activation_scale_ub
-)
-y = torch.ops.fbgemm.f8f8bf16_rowwise(
-xq, w.weight, x_scale, w.scale, use_fast_accum=True
-)
+xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, activation_scale_ub)
+y = torch.ops.fbgemm.f8f8bf16_rowwise(xq, w.weight, x_scale, w.scale, use_fast_accum=True)
 del xq
 return y
@@ -17,8 +17,7 @@ from torch import Tensor

 @unittest.skipIf(
-not torch.cuda.is_available()
-or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
+not torch.cuda.is_available() or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
 "Skip when H100 is not available",
 )
 class FP8Tests(unittest.TestCase):
@@ -57,9 +57,7 @@ class HadamardModule(torch.nn.Module):
 return x

-def add_hadamard_transform_for_spinquant(
-model: torch.nn.Module, prefix: str = ""
-) -> None:
+def add_hadamard_transform_for_spinquant(model: torch.nn.Module, prefix: str = "") -> None:
 """
 Adds a Hadamard transform to the last linear layer of each feedforward network (FFN) in the model.
 This function recursively traverses the model's children and looks for layers that match the pattern

@@ -81,12 +79,8 @@ def add_hadamard_transform_for_spinquant(
 for module_name, module in model.named_children():
 child_full_name = prefix + "." + module_name
 if re.search(pattern_last_linear_ffn, child_full_name):
-new_module = nn.Sequential(
-HadamardModule(group_size=module.in_features), module
-)
+new_module = nn.Sequential(HadamardModule(group_size=module.in_features), module)
 del module
 setattr(model, module_name, new_module)
 else:
-add_hadamard_transform_for_spinquant(
-module, (prefix + "." if prefix else prefix) + module_name
-)
+add_hadamard_transform_for_spinquant(module, (prefix + "." if prefix else prefix) + module_name)
@@ -63,12 +63,8 @@ def convert_to_fp8_quantized_model(
 # Move weights to GPU with quantization
 if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
 log.info("Loading fp8 scales...")
-fp8_scales_path = os.path.join(
-checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
-)
-assert os.path.isfile(
-fp8_scales_path
-), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
+fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
+assert os.path.isfile(fp8_scales_path), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
 fp8_scales = torch.load(fp8_scales_path, weights_only=True)

 for block in model.layers:

@@ -81,9 +77,7 @@ def convert_to_fp8_quantized_model(
 param = getattr(block.feed_forward, key)
 param.weight = load_fp8(
 param.weight,
-fp8_scales[
-f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"
-],
+fp8_scales[f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"],
 fp8_activation_scale_ub,
 )
 else:
@@ -172,9 +166,7 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
 if prefix + "zeros" not in state_dict:
 # Zero-point may not be saved in the state dict. In this case, we assume it's zero.
 assert prefix + "scales" in state_dict
-state_dict[prefix + "zeros"] = torch.zeros_like(
-state_dict[prefix + "scales"]
-)
+state_dict[prefix + "zeros"] = torch.zeros_like(state_dict[prefix + "scales"])

 def forward(self, input_: torch.Tensor) -> torch.Tensor:
 module_out = super().forward(input_)

@@ -229,9 +221,7 @@ class Int8WeightLinear(torch.nn.Linear):
 bias: Whether to use bias.
 """

-def __init__(
-self, in_features: int, out_features: int, bias: bool = True, device=None
-) -> None:
+def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None) -> None:
 super().__init__(in_features, out_features, bias, device=device)

 self._register_load_state_dict_pre_hook(self.load_hook)

@@ -295,9 +285,7 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
 del module
 setattr(model, module_name, quantized_module)
 else:
-_prepare_model_int4_weight_int8_dynamic_activation(
-module, group_size, lora_rank, lora_scale
-)
+_prepare_model_int4_weight_int8_dynamic_activation(module, group_size, lora_rank, lora_scale)

 return model

@@ -321,9 +309,7 @@ def convert_to_int4_quantized_model(
 group_size = model_args.quantization_args.group_size
 if group_size is None:
-raise ValueError(
-"'group_size' cannot be None in 'quantization_args'. Please specify it."
-)
+raise ValueError("'group_size' cannot be None in 'quantization_args'. Please specify it.")

 if model_args.lora_args is None:
 # Certain quantized models (e.g., SpinQuant) may not have LoRA.

@@ -333,8 +319,6 @@ def convert_to_int4_quantized_model(
 lora_rank = model_args.lora_args.rank
 lora_scale = model_args.lora_args.scale

-_prepare_model_int4_weight_int8_dynamic_activation(
-model, group_size, lora_rank, lora_scale
-)
+_prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 return model.to(device)
@@ -76,9 +76,9 @@ def main(
 checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
 assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
-assert model_parallel_size == len(
-checkpoints
-), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
+assert model_parallel_size == len(checkpoints), (
+f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
+)
 ckpt_path = checkpoints[get_model_parallel_rank()]
 checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
 with open(Path(ckpt_dir) / "params.json", "r") as f:

@@ -90,9 +90,9 @@ def main(
 **params,
 )
 tokenizer = Tokenizer(model_path=tokenizer_path)
-assert (
-model_args.vocab_size == tokenizer.n_words
-), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
+assert model_args.vocab_size == tokenizer.n_words, (
+f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
+)

 # load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
 torch.set_default_tensor_type(torch.BFloat16Tensor)

@@ -106,9 +106,7 @@ def main(
 torch.set_default_tensor_type(torch.cuda.HalfTensor)

 log.info(ckpt_path)
-assert (
-quantized_ckpt_dir is not None
-), "QUantized checkpoint directory should not be None"
+assert quantized_ckpt_dir is not None, "QUantized checkpoint directory should not be None"
 fp8_scales = {}
 for block in model.layers:
 if isinstance(block, TransformerBlock):

@@ -122,9 +120,7 @@ def main(
 )
 with torch.inference_mode():
 block.feed_forward.w1.weight = Parameter(fp8_weight.weight)
-fp8_scales[
-f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
-] = fp8_weight.scale
+fp8_scales[f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"] = fp8_weight.scale

 fp8_weight = quantize_fp8(
 block.feed_forward.w3.weight,

@@ -133,9 +129,7 @@ def main(
 )
 with torch.inference_mode():
 block.feed_forward.w3.weight = Parameter(fp8_weight.weight)
-fp8_scales[
-f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
-] = fp8_weight.scale
+fp8_scales[f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"] = fp8_weight.scale

 fp8_weight = quantize_fp8(
 block.feed_forward.w2.weight,

@@ -144,13 +138,9 @@ def main(
 )
 with torch.inference_mode():
 block.feed_forward.w2.weight = Parameter(fp8_weight.weight)
-fp8_scales[
-f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
-] = fp8_weight.scale
+fp8_scales[f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"] = fp8_weight.scale

-fp8_scales_path = os.path.join(
-quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
-)
+fp8_scales_path = os.path.join(quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
 torch.save(fp8_scales, fp8_scales_path)

 ckpt_path = os.path.join(
@@ -10,7 +10,6 @@ from pydantic import BaseModel

 class SentenceTransformersInferenceConfig(BaseModel):

 @classmethod
 def sample_run_config(cls) -> Dict[str, Any]:
 return {}
@@ -53,7 +53,5 @@ class VLLMConfig(BaseModel):
 repos = [m.huggingface_repo for m in permitted_models]
 if model not in (descriptors + repos):
 model_list = "\n\t".join(repos)
-raise ValueError(
-f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
-)
+raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
 return model
@@ -176,13 +176,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
 log.info("Sampling params: %s", sampling_params)
 request_id = _random_uuid()

-prompt = await chat_completion_request_to_prompt(
-request, self.config.model, self.formatter
-)
+prompt = await chat_completion_request_to_prompt(request, self.config.model, self.formatter)
 vllm_sampling_params = self._sampling_params(request.sampling_params)
-results_generator = self.engine.generate(
-prompt, vllm_sampling_params, request_id
-)
+results_generator = self.engine.generate(prompt, vllm_sampling_params, request_id)
 if stream:
 return self._stream_chat_completion(request, results_generator)
 else:
@@ -230,12 +226,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
 )

 stream = _generate_and_convert_to_openai_compat()
-async for chunk in process_chat_completion_stream_response(
-stream, self.formatter
-):
+async for chunk in process_chat_completion_stream_response(stream, self.formatter):
 yield chunk

-async def embeddings(
-self, model_id: str, contents: List[InterleavedContent]
-) -> EmbeddingsResponse:
+async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
 raise NotImplementedError()
@@ -47,6 +47,4 @@ async def validate_input_dataset_schema(
 if dataset_type not in EXPECTED_DATASET_SCHEMA:
 raise ValueError(f"Dataset type {dataset_type} is not supported.")

-validate_dataset_schema(
-dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type]
-)
+validate_dataset_schema(dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type])
@@ -42,9 +42,7 @@ class TorchtuneCheckpointer:
 self._model_type = ModelType[model_type]
 self._output_dir = output_dir
 # get ckpt paths
-self._checkpoint_path = Path.joinpath(
-self._checkpoint_dir, self._checkpoint_file
-)
+self._checkpoint_path = Path.joinpath(self._checkpoint_dir, self._checkpoint_file)

 def load_checkpoint(self) -> Dict[str, Any]:
 """

@@ -57,13 +55,9 @@ class TorchtuneCheckpointer:
 llama3_vision_meta_to_tune,
 )

-state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(
-model_state_dict
-)
+state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(model_state_dict)
 else:
-state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(
-model_state_dict
-)
+state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(model_state_dict)

 # llama3_2 has tied weights, so we need to remove the output.weight key
 if self._model_type == ModelType.LLAMA3_2:

@@ -82,10 +76,7 @@ class TorchtuneCheckpointer:
 epoch: int,
 adapter_only: bool = False,
 ) -> str:
-model_file_path = (
-Path(self._output_dir)
-/ f"{self._model_id}-{self._training_algorithm}-{epoch}"
-)
+model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"

 model_file_path.mkdir(parents=True, exist_ok=True)

@@ -116,22 +107,13 @@ class TorchtuneCheckpointer:
 llama3_vision_tune_to_meta,
 )

-state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
-model_state_dict
-)
+state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(model_state_dict)
 else:
 # llama3_2 has tied weights, so we need to add the output.weight key
-if (
-self._model_type == ModelType.LLAMA3_2
-and "output.weight" not in model_state_dict
-):
-model_state_dict["output.weight"] = model_state_dict[
-"tok_embeddings.weight"
-]
+if self._model_type == ModelType.LLAMA3_2 and "output.weight" not in model_state_dict:
+model_state_dict["output.weight"] = model_state_dict["tok_embeddings.weight"]

-state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
-model_state_dict
-)
+state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(model_state_dict)

 model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")
@@ -15,18 +15,13 @@ from typing import Any, Mapping
 from llama_stack.providers.utils.common.data_schema_validator import ColumnName


-def llama_stack_instruct_to_torchtune_instruct(
-    sample: Mapping[str, Any]
-) -> Mapping[str, Any]:
-    assert (
-        ColumnName.chat_completion_input.value in sample
-        and ColumnName.expected_answer.value in sample
-    ), "Invalid input row"
+def llama_stack_instruct_to_torchtune_instruct(sample: Mapping[str, Any]) -> Mapping[str, Any]:
+    assert ColumnName.chat_completion_input.value in sample and ColumnName.expected_answer.value in sample, (
+        "Invalid input row"
+    )
     input_messages = eval(str(sample[ColumnName.chat_completion_input.value]))

-    assert (
-        len(input_messages) == 1
-    ), "llama stack intruct dataset format only supports 1 user message"
+    assert len(input_messages) == 1, "llama stack intruct dataset format only supports 1 user message"
     input_message = input_messages[0]

     assert "content" in input_message, "content not found in input message"
@@ -48,13 +43,9 @@ def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str
     roles = []
     conversations = []
     for message in dialog:
-        assert (
-            "role" in message and "content" in message
-        ), "role and content must in message"
+        assert "role" in message and "content" in message, "role and content must in message"
         roles.append(message["role"])
-        conversations.append(
-            {"from": role_map[message["role"]], "value": message["content"]}
-        )
+        conversations.append({"from": role_map[message["role"]], "value": message["content"]})

     assert roles[0] == "user", "first message must be from user"
     assert "assistant" in roles, "at least 1 message should be from assistant"
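The two hunks above are representative of the assert reformatting applied across this commit: under ruff's 120-character line length, the condition and message usually fit on one line, and only an over-long message is kept in trailing parentheses. A rough illustrative sketch of the pattern (the function and messages here are hypothetical, not taken from this diff):

def check_row(row: dict) -> None:
    # Short condition and message fit on a single line under line-length = 120.
    assert "content" in row, "content not found in row"
    # A message that would push the line past the limit moves into trailing parentheses instead.
    assert isinstance(row["content"], str), (
        "content must be a string; the long explanation stays parenthesized so the assert line stays under the limit"
    )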
@@ -61,8 +61,7 @@ class SFTDataset(Dataset):
         if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
             keys_str = ", ".join(tokenized_dict.keys())
             error_message = (
-                "model_transform returned the following keys: "
-                f"{keys_str}. Must return 'tokens' and 'mask' as keys."
+                f"model_transform returned the following keys: {keys_str}. Must return 'tokens' and 'mask' as keys."
             )
             raise ValueError(error_message)
@@ -119,9 +119,7 @@ class TorchtunePostTrainingImpl:
         return ListPostTrainingJobsResponse(data=self.jobs_list)

     @webmethod(route="/post-training/job/status")
-    async def get_training_job_status(
-        self, job_uuid: str
-    ) -> Optional[PostTrainingJobStatusResponse]:
+    async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
         if job_uuid in self.jobs_status:
             return self.jobs_status[job_uuid]
         return None
@@ -131,12 +129,8 @@ class TorchtunePostTrainingImpl:
         raise NotImplementedError("Job cancel is not implemented yet")

     @webmethod(route="/post-training/job/artifacts")
-    async def get_training_job_artifacts(
-        self, job_uuid: str
-    ) -> Optional[PostTrainingJobArtifactsResponse]:
+    async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
         if job_uuid in self.checkpoints_dict:
             checkpoints = self.checkpoints_dict.get(job_uuid, [])
-            return PostTrainingJobArtifactsResponse(
-                job_uuid=job_uuid, checkpoints=checkpoints
-            )
+            return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=checkpoints)
         return None
@@ -94,9 +94,7 @@ class LoraFinetuningSingleDevice:
         self.job_uuid = job_uuid
         self.training_config = training_config
         if not isinstance(algorithm_config, LoraFinetuningConfig):
-            raise ValueError(
-                "You need to speicifc LoraFinetuningConfig for LoRA finetuning"
-            )
+            raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning")
         self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device(device="cuda")
         self._dtype = training.get_dtype(training_config.dtype, device=self._device)
@@ -105,10 +103,7 @@ class LoraFinetuningSingleDevice:
         def model_checkpoint_dir(model) -> str:
             checkpoint_dir = Path(model_local_dir(model.descriptor()))

-            paths = [
-                Path(checkpoint_dir / f"consolidated.{ext}")
-                for ext in ["pth", "00.pth"]
-            ]
+            paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
             if not any(p.exists() for p in paths):
                 checkpoint_dir = checkpoint_dir / "original"

@@ -123,9 +118,7 @@ class LoraFinetuningSingleDevice:
         else:
             model = resolve_model(self.model_id)
             if model is None:
-                raise ValueError(
-                    f"{self.model_id} not found. Your model id should be in the llama models SKU list"
-                )
+                raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
             self.checkpoint_dir = model_checkpoint_dir(model)

         self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
@@ -196,9 +189,7 @@ class LoraFinetuningSingleDevice:
         self._tokenizer = await self._setup_tokenizer()
         log.info("Tokenizer is initialized.")

-        self._optimizer = await self._setup_optimizer(
-            optimizer_config=self.training_config.optimizer_config
-        )
+        self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config)
         log.info("Optimizer is initialized.")

         self._loss_fn = CEWithChunkedOutputLoss()
@@ -226,13 +217,8 @@ class LoraFinetuningSingleDevice:
         # by the dataloader and the max_steps_per_epoch param set by the user and is used
         # for logging and tracking training state. This should be computed after the dataloader
         # has been setup
-        self._steps_per_epoch = (
-            len(self._training_dataloader) // self._gradient_accumulation_steps
-        )
-        if (
-            self.max_steps_per_epoch is not None
-            and self.max_steps_per_epoch < self._steps_per_epoch
-        ):
+        self._steps_per_epoch = len(self._training_dataloader) // self._gradient_accumulation_steps
+        if self.max_steps_per_epoch is not None and self.max_steps_per_epoch < self._steps_per_epoch:
             self._steps_per_epoch = self.max_steps_per_epoch
         self.global_step = self.epochs_run * self._steps_per_epoch

@@ -246,9 +232,7 @@ class LoraFinetuningSingleDevice:
         log.info("Learning rate scheduler is initialized.")

         # Used to ignore labels for loss computation
-        self.ignore_labels_cache = torch.full(
-            (self._batch_size, 1), self._loss_fn.ignore_index, device=self._device
-        )
+        self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device)

     async def _setup_model(
         self,
@@ -282,13 +266,9 @@ class LoraFinetuningSingleDevice:
         set_trainable_params(model, self.adapter_params)

         if enable_activation_checkpointing:
-            training.set_activation_checkpointing(
-                model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
-            )
+            training.set_activation_checkpointing(model, auto_wrap_policy={modules.TransformerSelfAttentionLayer})

-        base_missing, base_unexpected = model.load_state_dict(
-            base_model_state_dict, strict=False
-        )
+        base_missing, base_unexpected = model.load_state_dict(base_model_state_dict, strict=False)

         # This is for any adapters that need to be initialized after base weights
         # have been loaded (e.g. DoRA).
@@ -297,9 +277,7 @@ class LoraFinetuningSingleDevice:
             if hasattr(m, "initialize_dora_magnitude"):
                 m.initialize_dora_magnitude()
         if lora_weights_state_dict:
-            lora_missing, lora_unexpected = model.load_state_dict(
-                lora_weights_state_dict, strict=False
-            )
+            lora_missing, lora_unexpected = model.load_state_dict(lora_weights_state_dict, strict=False)
         else:
             lora_missing, lora_unexpected = None, None
         validate_missing_and_unexpected_for_lora(
@@ -313,14 +291,10 @@ class LoraFinetuningSingleDevice:
         )

         # Validate model adapter params were loaded in with the expected dtype
-        training.validate_expected_param_dtype(
-            self.adapter_params.items(), dtype=self._dtype
-        )
+        training.validate_expected_param_dtype(self.adapter_params.items(), dtype=self._dtype)

         # activation offloading
-        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
-            model, enable_activation_offloading
-        )
+        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading)

         memory_stats = training.get_memory_stats(device=self._device)
         training.log_memory_stats(memory_stats)
@@ -456,9 +430,7 @@ class LoraFinetuningSingleDevice:
         # Shift labels to compute loss
         # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
         # But this way we dont need to slice the logits. We just add an ignore index to labels.
-        labels = torch.hstack(
-            (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
-        )
+        labels = torch.hstack((labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]))
         if not isinstance(logits, list):
             labels = labels.reshape(-1)
             logits = logits.reshape(-1, logits.size(-1))
@@ -487,9 +459,7 @@ class LoraFinetuningSingleDevice:
         for curr_epoch in range(self.epochs_run, self.total_epochs):
             # Update the sampler to ensure data is correctly shuffled across epochs
             # in case shuffle is True
-            metric_logger = DiskLogger(
-                log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
-            )
+            metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}")
             self._training_sampler.set_epoch(curr_epoch)
             loss_to_log = 0.0

@@ -497,8 +467,7 @@ class LoraFinetuningSingleDevice:
             for idx, batch in enumerate(self._training_dataloader):
                 if (
                     self.max_steps_per_epoch is not None
-                    and (idx // self._gradient_accumulation_steps)
-                    == self.max_steps_per_epoch
+                    and (idx // self._gradient_accumulation_steps) == self.max_steps_per_epoch
                 ):
                     break

@@ -506,9 +475,7 @@ class LoraFinetuningSingleDevice:

                 # Calculate the number of unmasked tokens in the current batch
                 # and increment the total number of tokens seen in the step
-                current_num_tokens = (
-                    batch["labels"] != self._loss_fn.ignore_index
-                ).sum()
+                current_num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum()
                 num_tokens += current_num_tokens

                 # Loss is normalized by default so we multiply by the number of tokens
@@ -533,9 +500,7 @@ class LoraFinetuningSingleDevice:
                     loss_to_log = running_loss.item() / num_tokens

                     pbar.update(1)
-                    pbar.set_description(
-                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
-                    )
+                    pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")

                     time_per_step = time.perf_counter() - t0
                     log_dict = {
@@ -67,10 +67,6 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
             violation = SafetyViolation(
                 violation_level=(ViolationLevel.ERROR),
                 user_message="Sorry, I found security concerns in the code.",
-                metadata={
-                    "violation_type": ",".join(
-                        [issue.pattern_id for issue in result.issues_found]
-                    )
-                },
+                metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])},
             )
         return RunShieldResponse(violation=violation)
@@ -10,9 +10,7 @@ from .config import LlamaGuardConfig
 async def get_provider_impl(config: LlamaGuardConfig, deps):
     from .llama_guard import LlamaGuardSafetyImpl

-    assert isinstance(
-        config, LlamaGuardConfig
-    ), f"Unexpected config type: {type(config)}"
+    assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"

     impl = LlamaGuardSafetyImpl(config, deps)
     await impl.initialize()
@@ -102,8 +102,7 @@ LLAMA_GUARD_MODEL_IDS = {
 }

 MODEL_TO_SAFETY_CATEGORIES_MAP = {
-    "meta-llama/Llama-Guard-3-8B": DEFAULT_LG_V3_SAFETY_CATEGORIES
-    + [CAT_CODE_INTERPRETER_ABUSE],
+    "meta-llama/Llama-Guard-3-8B": DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE],
     "meta-llama/Llama-Guard-3-1B": DEFAULT_LG_V3_SAFETY_CATEGORIES,
     "meta-llama/Llama-Guard-3-11B-Vision": DEFAULT_LG_V3_SAFETY_CATEGORIES,
 }
@@ -133,9 +132,7 @@ Provide your safety assessment for ONLY THE LAST $agent_type message in the abov
 - If unsafe, a second line must include a comma-separated list of violated categories."""


-PROMPT_TEMPLATE = Template(
-    f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}"
-)
+PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}")


 class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
@@ -233,9 +230,7 @@ class LlamaGuardShield:
         if messages[0].role != Role.user.value:
             raise ValueError("Messages must start with user")

-        if len(messages) >= 2 and (
-            messages[0].role == Role.user.value and messages[1].role == Role.user.value
-        ):
+        if len(messages) >= 2 and (messages[0].role == Role.user.value and messages[1].role == Role.user.value):
             messages = messages[1:]

         for i in range(1, len(messages)):
@@ -263,10 +258,7 @@ class LlamaGuardShield:
             stream=True,
         ):
             event = chunk.event
-            if (
-                event.event_type == ChatCompletionResponseEventType.progress
-                and event.delta.type == "text"
-            ):
+            if event.event_type == ChatCompletionResponseEventType.progress and event.delta.type == "text":
                 content += event.delta.text

         content = content.strip()
@@ -313,10 +305,7 @@ class LlamaGuardShield:
         categories = self.get_safety_categories()
         categories_str = "\n".join(categories)
         conversations_str = "\n\n".join(
-            [
-                f"{m.role.capitalize()}: {interleaved_content_as_str(m.content)}"
-                for m in messages
-            ]
+            [f"{m.role.capitalize()}: {interleaved_content_as_str(m.content)}" for m in messages]
         )
         return PROMPT_TEMPLATE.substitute(
             agent_type=messages[-1].role.capitalize(),
@@ -46,9 +46,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):

     async def register_shield(self, shield: Shield) -> None:
         if shield.provider_resource_id != PROMPT_GUARD_MODEL:
-            raise ValueError(
-                f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. "
-            )
+            raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ")

     async def run_shield(
         self,
@@ -71,9 +69,7 @@ class PromptGuardShield:
         threshold: float = 0.9,
         temperature: float = 1.0,
     ):
-        assert (
-            model_dir is not None
-        ), "Must provide a model directory for prompt injection shield"
+        assert model_dir is not None, "Must provide a model directory for prompt injection shield"
         if temperature <= 0:
             raise ValueError("Temperature must be greater than 0")

@@ -85,9 +81,7 @@ class PromptGuardShield:

         # load model and tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
-        self.model = AutoModelForSequenceClassification.from_pretrained(
-            model_dir, device_map=self.device
-        )
+        self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)

     async def run(self, messages: List[Message]) -> RunShieldResponse:
         message = messages[-1]
@@ -117,10 +111,7 @@ class PromptGuardShield:
                     "violation_type": f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
                 },
             )
-        elif (
-            self.config.guard_type == PromptGuardType.jailbreak.value
-            and score_malicious > self.threshold
-        ):
+        elif self.config.guard_type == PromptGuardType.jailbreak.value and score_malicious > self.threshold:
             violation = SafetyViolation(
                 violation_level=ViolationLevel.ERROR,
                 violation_type=f"prompt_injection:malicious={score_malicious}",
@@ -54,15 +54,11 @@ class BasicScoringImpl(

     async def list_scoring_functions(self) -> List[ScoringFn]:
         scoring_fn_defs_list = [
-            fn_def
-            for impl in self.scoring_fn_id_impls.values()
-            for fn_def in impl.get_supported_scoring_fn_defs()
+            fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
         ]

         for f in scoring_fn_defs_list:
-            assert f.identifier.startswith(
-                "basic"
-            ), "All basic scoring fn must have identifier prefixed with 'basic'! "
+            assert f.identifier.startswith("basic"), "All basic scoring fn must have identifier prefixed with 'basic'! "

         return scoring_fn_defs_list

@@ -76,9 +72,7 @@ class BasicScoringImpl(
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(
-            dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
-        )
+        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))

         all_rows = await self.datasetio_api.get_rows_paginated(
             dataset_id=dataset_id,
@@ -108,12 +102,8 @@ class BasicScoringImpl(
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
             scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
-            score_results = await scoring_fn.score(
-                input_rows, scoring_fn_id, scoring_fn_params
-            )
-            agg_results = await scoring_fn.aggregate(
-                score_results, scoring_fn_id, scoring_fn_params
-            )
+            score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
+            agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
                 aggregated_results=agg_results,
@@ -32,9 +32,7 @@ class EqualityScoringFn(RegisteredBaseScoringFn):
         scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
         assert "expected_answer" in input_row, "Expected answer not found in input row."
-        assert (
-            "generated_answer" in input_row
-        ), "Generated answer not found in input row."
+        assert "generated_answer" in input_row, "Generated answer not found in input row."

         expected_answer = input_row["expected_answer"]
         generated_answer = input_row["generated_answer"]
@@ -18,7 +18,5 @@ equality = ScoringFn(
     provider_id="basic",
     provider_resource_id="equality",
     return_type=NumberType(),
-    params=BasicScoringFnParams(
-        aggregation_functions=[AggregationFunctionType.accuracy]
-    ),
+    params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
 )
@@ -55,9 +55,7 @@ MULTILINGUAL_ANSWER_REGEXES = [
     r"Àṣàyàn\s*:",
 ]

-MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
-    r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])"
-)
+MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])"

 regex_parser_multiple_choice_answer = ScoringFn(
     identifier="basic::regex_parser_multiple_choice_answer",
@@ -66,10 +64,7 @@ regex_parser_multiple_choice_answer = ScoringFn(
     provider_id="basic",
     provider_resource_id="regex-parser-multiple-choice-answer",
     params=RegexParserScoringFnParams(
-        parsing_regexes=[
-            MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
-            for x in MULTILINGUAL_ANSWER_REGEXES
-        ],
+        parsing_regexes=[MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x) for x in MULTILINGUAL_ANSWER_REGEXES],
         aggregation_functions=[AggregationFunctionType.accuracy],
     ),
 )
@@ -18,7 +18,5 @@ subset_of = ScoringFn(
     return_type=NumberType(),
     provider_id="basic",
     provider_resource_id="subset-of",
-    params=BasicScoringFnParams(
-        aggregation_functions=[AggregationFunctionType.accuracy]
-    ),
+    params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
 )
@@ -33,17 +33,14 @@ class RegexParserScoringFn(RegisteredBaseScoringFn):
         scoring_fn_identifier: Optional[str] = None,
         scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
-        assert (
-            scoring_fn_identifier is not None
-        ), "Scoring function identifier not found."
+        assert scoring_fn_identifier is not None, "Scoring function identifier not found."
         fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
         if scoring_params is not None:
             fn_def.params = scoring_params

-        assert (
-            fn_def.params is not None
-            and fn_def.params.type == ScoringFnParamsType.regex_parser.value
-        ), f"RegexParserScoringFnParams not found for {fn_def}."
+        assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
+            f"RegexParserScoringFnParams not found for {fn_def}."
+        )

         expected_answer = input_row["expected_answer"]
         generated_answer = input_row["generated_answer"]
@@ -124,12 +124,10 @@ class BraintrustScoringImpl(
         self.datasets_api = datasets_api

         self.braintrust_evaluators = {
-            entry.identifier: entry.evaluator
-            for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
+            entry.identifier: entry.evaluator for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
         }
         self.supported_fn_defs_registry = {
-            entry.identifier: entry.fn_def
-            for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
+            entry.identifier: entry.fn_def for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
         }

     async def initialize(self) -> None: ...
@@ -139,16 +137,14 @@ class BraintrustScoringImpl(
     async def list_scoring_functions(self) -> List[ScoringFn]:
         scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
         for f in scoring_fn_defs_list:
-            assert f.identifier.startswith(
-                "braintrust"
-            ), "All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
+            assert f.identifier.startswith("braintrust"), (
+                "All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
+            )

         return scoring_fn_defs_list

     async def register_scoring_function(self, scoring_fn: ScoringFn) -> None:
-        raise NotImplementedError(
-            "Registering scoring function not allowed for braintrust provider"
-        )
+        raise NotImplementedError("Registering scoring function not allowed for braintrust provider")

     async def set_api_key(self) -> None:
         # api key is in the request headers
@@ -171,17 +167,13 @@ class BraintrustScoringImpl(
         await self.set_api_key()

         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(
-            dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
-        )
+        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))

         all_rows = await self.datasetio_api.get_rows_paginated(
             dataset_id=dataset_id,
             rows_in_page=-1,
         )
-        res = await self.score(
-            input_rows=all_rows.rows, scoring_functions=scoring_functions
-        )
+        res = await self.score(input_rows=all_rows.rows, scoring_functions=scoring_functions)
         if save_results_dataset:
             # TODO: persist and register dataset on to server for reading
             # self.datasets_api.register_dataset()
@@ -222,13 +214,8 @@ class BraintrustScoringImpl(
             if scoring_fn_id not in self.supported_fn_defs_registry:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")

-            score_results = [
-                await self.score_row(input_row, scoring_fn_id)
-                for input_row in input_rows
-            ]
-            aggregation_functions = self.supported_fn_defs_registry[
-                scoring_fn_id
-            ].params.aggregation_functions
+            score_results = [await self.score_row(input_row, scoring_fn_id) for input_row in input_rows]
+            aggregation_functions = self.supported_fn_defs_registry[scoring_fn_id].params.aggregation_functions

             # override scoring_fn params if provided
             if scoring_functions[scoring_fn_id] is not None:
@@ -21,7 +21,5 @@ answer_correctness_fn_def = ScoringFn(
     provider_id="braintrust",
     provider_resource_id="answer-correctness",
     return_type=NumberType(),
-    params=BasicScoringFnParams(
-        aggregation_functions=[AggregationFunctionType.average]
-    ),
+    params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
 )
@@ -20,7 +20,5 @@ answer_relevancy_fn_def = ScoringFn(
     provider_id="braintrust",
     provider_resource_id="answer-relevancy",
     return_type=NumberType(),
-    params=BasicScoringFnParams(
-        aggregation_functions=[AggregationFunctionType.average]
-    ),
+    params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
 )
@@ -20,7 +20,5 @@ answer_similarity_fn_def = ScoringFn(
     provider_id="braintrust",
     provider_resource_id="answer-similarity",
     return_type=NumberType(),
-    params=BasicScoringFnParams(
-        aggregation_functions=[AggregationFunctionType.average]
-    ),
+    params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
 )
Some files were not shown because too many files have changed in this diff.