Merge remote-tracking branch 'upstream/main' into add_nvidia_safety_provider

Merging upstream changes
Chantal D Gama Rose 2025-02-21 00:39:45 +00:00
commit 78b1105f5d
112 changed files with 5112 additions and 3313 deletions


@ -75,19 +75,20 @@ repos:
# - id: markdown-link-check
# args: ['--quiet']
# - repo: local
# hooks:
# - id: distro-codegen
# name: Distribution Template Codegen
# additional_dependencies:
# - rich
# - pydantic
# entry: python -m llama_stack.scripts.distro_codegen
# language: python
# pass_filenames: false
# require_serial: true
# files: ^llama_stack/templates/.*$
# stages: [manual]
- repo: local
hooks:
- id: distro-codegen
name: Distribution Template Codegen
additional_dependencies:
- rich
- pydantic
- uv==0.6.0
entry: uv run python -m llama_stack.scripts.distro_codegen
language: python
pass_filenames: false
require_serial: true
files: ^llama_stack/templates/.*$
files: ^llama_stack/providers/.*/inference/.*/models\.py$
ci:
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks


@ -139,6 +139,16 @@ $ make html
$ uv run sphinx-autobuild source build/html
```
### Update API Documentation
If you modify or add new API endpoints, update the API documentation accordingly. You can do this by running the following commands:
```bash
$ uv sync --extra dev
$ ./docs/openapi_generator/run_openapi_generator.sh
```
The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing.
## License
By contributing to Llama, you agree that your contributions will be licensed


@ -21,6 +21,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@ -54,6 +55,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@ -88,6 +90,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@ -122,6 +125,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@ -157,6 +161,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@ -192,6 +197,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@ -228,6 +234,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@ -269,6 +276,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@ -306,6 +314,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@ -340,6 +349,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@ -373,6 +383,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@ -403,6 +414,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@ -438,6 +450,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@ -471,6 +484,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@ -505,6 +519,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",

File diff suppressed because it is too large

File diff suppressed because it is too large


@ -1017,14 +1017,14 @@
" \"content\": SYSTEM_PROMPT_TEMPLATE.format(subject=subset),\n",
"}\n",
"\n",
"client.eval_tasks.register(\n",
" eval_task_id=\"meta-reference::mmmu\",\n",
"client.benchmarks.register(\n",
" benchmark_id=\"meta-reference::mmmu\",\n",
" dataset_id=f\"mmmu-{subset}-{split}\",\n",
" scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"],\n",
")\n",
"\n",
"response = client.eval.evaluate_rows(\n",
" task_id=\"meta-reference::mmmu\",\n",
"response = client.eval.evaluate_rows_alpha(\n",
" benchmark_id=\"meta-reference::mmmu\",\n",
" input_rows=eval_rows,\n",
" scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"],\n",
" task_config={\n",
@ -1196,14 +1196,14 @@
" provider_id=\"together\",\n",
")\n",
"\n",
"client.eval_tasks.register(\n",
" eval_task_id=\"meta-reference::simpleqa\",\n",
"client.benchmarks.register(\n",
" benchmark_id=\"meta-reference::simpleqa\",\n",
" dataset_id=simpleqa_dataset_id,\n",
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
")\n",
"\n",
"response = client.eval.evaluate_rows(\n",
" task_id=\"meta-reference::simpleqa\",\n",
"response = client.eval.evaluate_rows_alpha(\n",
" benchmark_id=\"meta-reference::simpleqa\",\n",
" input_rows=eval_rows.rows,\n",
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
" task_config={\n",
@ -1351,8 +1351,8 @@
" \"enable_session_persistence\": False,\n",
"}\n",
"\n",
"response = client.eval.evaluate_rows(\n",
" task_id=\"meta-reference::simpleqa\",\n",
"response = client.eval.evaluate_rows_alpha(\n",
" benchmark_id=\"meta-reference::simpleqa\",\n",
" input_rows=eval_rows.rows,\n",
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
" task_config={\n",


@ -3,7 +3,7 @@ The RFC Specification (OpenAPI format) is generated from the set of API endpoint
Please install the following packages before running the script:
```
pip install python-openapi json-strong-typing fire PyYAML llama-models
pip install fire PyYAML llama-models
```
Then simply run `sh run_openapi_generator.sh`.


@ -477,6 +477,7 @@ class Generator:
"SyntheticDataGeneration",
"PostTraining",
"BatchInference",
"Files",
]:
op.defining_class.__name__ = f"{op.defining_class.__name__} (Coming Soon)"
print(op.defining_class.__name__)
@ -520,8 +521,30 @@ class Generator:
# parameters passed anywhere
parameters = path_parameters + query_parameters
# data passed in payload
if op.request_params:
webmethod = getattr(op.func_ref, "__webmethod__", None)
raw_bytes_request_body = False
if webmethod:
raw_bytes_request_body = getattr(webmethod, "raw_bytes_request_body", False)
# data passed in request body as raw bytes cannot have request parameters
if raw_bytes_request_body and op.request_params:
raise ValueError("Cannot have both raw bytes request body and request parameters")
# data passed in request body as raw bytes
if raw_bytes_request_body:
requestBody = RequestBody(
content={
"application/octet-stream": {
"schema": {
"type": "string",
"format": "binary",
}
}
},
required=True,
)
# data passed in payload as JSON and mapped to request parameters
elif op.request_params:
builder = ContentBuilder(self.schema_builder)
first = next(iter(op.request_params))
request_name, request_type = first
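
For context, here is a minimal runnable sketch of the attribute this code inspects. The `WebMethod` stand-in, the decorator body, and the `/files/upload` route are illustrative assumptions, not the real `llama_stack` implementation; they only mirror the shape the generator reads via `getattr(op.func_ref, "__webmethod__", None)`.
```python
from dataclasses import dataclass


@dataclass
class WebMethod:
    # Stand-in for the metadata object attached to endpoint functions.
    route: str
    method: str = "POST"
    raw_bytes_request_body: bool = False


def webmethod(**kwargs):
    # Attach the metadata under the attribute name the generator looks up.
    def wrap(func):
        func.__webmethod__ = WebMethod(**kwargs)
        return func
    return wrap


@webmethod(route="/files/upload", raw_bytes_request_body=True)
def upload_file(body: bytes) -> None:
    # Flagged endpoints are documented with a required
    # application/octet-stream body; declaring request parameters
    # alongside the flag triggers the ValueError above.
    ...
```
With the flag set, the generated spec advertises a required binary (`format: binary`) request body instead of a JSON payload built from request parameters.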


@ -150,7 +150,14 @@ def _get_endpoint_functions(
print(f"Processing {colored(func_name, 'white')}...")
operation_name = func_name
if operation_name.startswith("get_") or operation_name.endswith("/get"):
if webmethod.method == "GET":
prefix = "get"
elif webmethod.method == "DELETE":
prefix = "delete"
elif webmethod.method == "POST":
prefix = "post"
elif operation_name.startswith("get_") or operation_name.endswith("/get"):
prefix = "get"
elif (
operation_name.startswith("delete_")
@ -160,13 +167,8 @@ def _get_endpoint_functions(
):
prefix = "delete"
else:
if webmethod.method == "GET":
prefix = "get"
elif webmethod.method == "DELETE":
prefix = "delete"
else:
# by default everything else is a POST
prefix = "post"
# by default everything else is a POST
prefix = "post"
yield prefix, operation_name, func_name, func_ref
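
Distilled, this change gives an explicitly declared HTTP method precedence over the name-based heuristics, which now serve only as a fallback. A hedged distillation (the helper name and example operation names are illustrative, not from the source):
```python
from typing import Optional


def choose_prefix(http_method: Optional[str], operation_name: str) -> str:
    # An explicit method on the webmethod takes precedence.
    if http_method in ("GET", "DELETE", "POST"):
        return http_method.lower()
    # Fallback: infer the prefix from the operation name.
    if operation_name.startswith("get_") or operation_name.endswith("/get"):
        return "get"
    if operation_name.startswith("delete_") or operation_name.endswith("/delete"):
        return "delete"
    # By default everything else is a POST.
    return "post"


# A POST endpoint whose function name starts with "get_" is now
# classified by its declared method, not by its name.
assert choose_prefix("POST", "get_training_job_status") == "post"
assert choose_prefix(None, "get_training_job_status") == "get"
```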


@ -78,7 +78,7 @@ class MediaType:
@dataclass
class RequestBody:
content: Dict[str, MediaType]
content: Dict[str, MediaType | Dict[str, Any]]
description: Optional[str] = None
required: Optional[bool] = None


@ -1,4 +1,4 @@
sphinx
sphinx==8.1.3
myst-parser
linkify
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme


@ -122,3 +122,20 @@ response = agent.create_turn(
session_id=session_id,
)
```
### Unregistering Vector DBs
If you need to clean up and unregister vector databases, you can do so as follows:
```python
# Unregister a specified vector database
vector_db_id = "my_vector_db_id"
print(f"Unregistering vector database: {vector_db_id}")
client.vector_dbs.unregister(vector_db_id=vector_db_id)

# Unregister all vector databases
for vector_db in client.vector_dbs.list():
    print(f"Unregistering vector database: {vector_db.identifier}")
    client.vector_dbs.unregister(vector_db_id=vector_db.identifier)
```


@ -60,6 +60,11 @@ Features:
- Disabled dangerous system operations
- Configurable execution timeouts
> ⚠️ Important: The code interpreter tool can operate in a controlled environment locally or on Podman containers. To ensure proper functionality in containerised environments:
> - The container requires privileged access (e.g., --privileged).
> - Users without sufficient permissions may encounter permission errors. (`bwrap: Can't mount devpts on /newroot/dev/pts: Permission denied`)
> - 🔒 Security Warning: Privileged mode grants elevated access and bypasses security restrictions. Use only in local, isolated, or controlled environments.
#### WolframAlpha
The WolframAlpha tool provides access to computational knowledge through the WolframAlpha API.
@ -103,7 +108,7 @@ Features:
MCP tools are special tools that can interact with Llama Stack over the Model Context Protocol. These tools are dynamically discovered from an MCP endpoint and can be used to extend the agent's capabilities.
Refer to https://github.com/modelcontextprotocol/server for available MCP servers.
Refer to [https://github.com/modelcontextprotocol/servers](https://github.com/modelcontextprotocol/servers) for available MCP servers.
```python
# Register MCP tools
@ -191,3 +196,36 @@ all_tools = client.tools.list_tools()
# List tools in a specific group
group_tools = client.tools.list_tools(toolgroup_id="search_tools")
```
## Simple Example: Using an Agent with the Code-Interpreter Tool
```python
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.types.agent_create_params import AgentConfig
# Configure the AI agent with necessary parameters
agent_config = AgentConfig(
    name="code-interpreter",
    description="A code interpreter agent for executing Python code snippets",
    instructions="""
    You are a highly reliable, concise, and precise assistant.
    Always show the generated code, never generate your own code, and never anticipate results.
    """,
    model="meta-llama/Llama-3.2-3B-Instruct",
    toolgroups=["builtin::code_interpreter"],
    max_infer_iters=5,
    enable_session_persistence=False,
)
# Instantiate the AI agent with the given configuration
agent = Agent(client, agent_config)
# Start a session
session_id = agent.create_session("tool_session")
# Send a query to the AI agent for code execution
response = agent.create_turn(
    messages=[{"role": "user", "content": "Run this code: print(3 ** 4 - 5 * 2)"}],
    session_id=session_id,
)
```


@ -13,6 +13,7 @@ A Llama Stack API is described as a collection of REST endpoints. We currently s
- **DatasetIO**: interface with datasets and data loaders
- **Scoring**: evaluate outputs of the system
- **Eval**: generate outputs (via Inference or Agents) and perform scoring
- **VectorIO**: perform operations on vector stores, such as adding documents, searching, and deleting documents
- **Telemetry**: collect telemetry data from the system
We are working on adding a few more APIs to complete the application lifecycle. These will include:
@ -41,6 +42,7 @@ Some of these APIs are associated with a set of **Resources**. Here is the mappi
- **Safety** is associated with `Shield` resources.
- **Tool Runtime** is associated with `ToolGroup` resources.
- **DatasetIO** is associated with `Dataset` resources.
- **VectorIO** is associated with `VectorDB` resources.
- **Scoring** is associated with `ScoringFunction` resources.
- **Eval** is associated with `Model` and `Benchmark` resources.


@ -93,9 +93,10 @@ html_theme_options = {
html_static_path = ["../_static"]
# html_logo = "../_static/llama-stack-logo.png"
html_style = "../_static/css/my_theme.css"
# html_style = "../_static/css/my_theme.css"
def setup(app):
app.add_css_file("css/my_theme.css")
def dockerhub_role(name, rawtext, text, lineno, inliner, options={}, content=[]):
url = f"https://hub.docker.com/r/llamastack/{text}"
node = nodes.reference(rawtext, text, refuri=url, **options)


@ -6,7 +6,7 @@ This guide will walk you through the process of adding a new API provider to Lla
- Begin by reviewing the [core concepts](../concepts/index.md) of Llama Stack and choose the API your provider belongs to (Inference, Safety, VectorIO, etc.)
- Determine the provider type ({repopath}`Remote::llama_stack/providers/remote` or {repopath}`Inline::llama_stack/providers/inline`). Remote providers make requests to external services, while inline providers execute implementation locally.
- Add your provider to the appropriate {repopath}`Registry::llama_stack/providers/registry/`. Specify pip dependencies necessary.
- Update any distribution {repopath}`Templates::llama_stack/templates/` build.yaml and run.yaml files if they should include your provider by default. Run {repopath}`llama_stack/scripts/distro_codegen.py` if necessary.
- Update any distribution {repopath}`Templates::llama_stack/templates/` build.yaml and run.yaml files if they should include your provider by default. Run {repopath}`llama_stack/scripts/distro_codegen.py` if necessary. Note that `distro_codegen.py` will fail if the new provider causes any distribution template to attempt to import provider-specific dependencies. This usually means the distribution's `get_distribution_template()` code path should only import any necessary Config or model alias definitions from each provider and not the provider's actual implementation.
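
To illustrate the constraint in the last bullet, here is a hedged sketch of a distribution template's import discipline; the provider paths and class names are hypothetical and left commented out so nothing real is imported:
```python
# Safe for distro_codegen: pull in only lightweight definitions.
# from llama_stack.providers.remote.safety.my_provider.config import MyProviderConfig

# Unsafe: importing the implementation module can drag in provider-specific
# SDKs that are not installed in the codegen environment, making
# distro_codegen fail.
# from llama_stack.providers.remote.safety.my_provider.my_provider import MyProviderAdapter


def get_distribution_template():
    # Build the template from Config and model-alias definitions only;
    # the provider's actual implementation is never imported here.
    ...
```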
Here are some example PRs to help you get started:


@ -10,7 +10,7 @@ conda_env: ollama
apis:
- agents
- inference
- memory
- vector_io
- safety
- telemetry
providers:
@ -19,7 +19,7 @@ providers:
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:http://localhost:11434}
memory:
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:


@ -61,7 +61,8 @@ docker run \
--port $LLAMA_STACK_PORT \
--env AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
--env AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
--env AWS_SESSION_TOKEN=$AWS_SESSION_TOKEN
--env AWS_SESSION_TOKEN=$AWS_SESSION_TOKEN \
--env AWS_DEFAULT_REGION=$AWS_DEFAULT_REGION
```
### Via Conda
@ -72,5 +73,6 @@ llama stack run ./run.yaml \
--port $LLAMA_STACK_PORT \
--env AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
--env AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
--env AWS_SESSION_TOKEN=$AWS_SESSION_TOKEN
--env AWS_SESSION_TOKEN=$AWS_SESSION_TOKEN \
--env AWS_DEFAULT_REGION=$AWS_DEFAULT_REGION
```


@ -47,6 +47,7 @@ The following models are available by default:
- `meta-llama/Llama-3.3-70B-Instruct (accounts/fireworks/models/llama-v3p3-70b-instruct)`
- `meta-llama/Llama-Guard-3-8B (accounts/fireworks/models/llama-guard-3-8b)`
- `meta-llama/Llama-Guard-3-11B-Vision (accounts/fireworks/models/llama-guard-3-11b-vision)`
- `nomic-ai/nomic-embed-text-v1.5 (nomic-ai/nomic-embed-text-v1.5)`
### Prerequisite: API Keys


@ -130,7 +130,7 @@ llama stack run ./run-with-safety.yaml \
### (Optional) Update Model Serving Configuration
```{note}
Please check the [model_aliases](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L45) for the supported Ollama models.
Please check the [model_entries](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L45) for the supported Ollama models.
```
To serve a new model with `ollama`


@ -46,6 +46,8 @@ The following models are available by default:
- `meta-llama/Llama-3.3-70B-Instruct`
- `meta-llama/Llama-Guard-3-8B`
- `meta-llama/Llama-Guard-3-11B-Vision`
- `togethercomputer/m2-bert-80M-8k-retrieval`
- `togethercomputer/m2-bert-80M-32k-retrieval`
### Prerequisite: API Keys


@ -214,10 +214,16 @@ documents = [
for i, url in enumerate(urls)
]
vector_providers = [
    provider for provider in client.providers.list() if provider.api == "vector_io"
]
provider_id = vector_providers[0].provider_id  # Use the first available vector provider

# Register a vector database
vector_db_id = f"test-vector-db-{uuid.uuid4().hex}"
client.vector_dbs.register(
    vector_db_id=vector_db_id,
    provider_id=provider_id,
    embedding_model="all-MiniLM-L6-v2",
    embedding_dimension=384,
)


@ -42,10 +42,10 @@ from llama_stack_client.types import (
Methods:
- <code title="get /v1/toolgroups">client.toolgroups.<a href="./src/llama_stack_client/resources/toolgroups.py">list</a>() -> <a href="./src/llama_stack_client/types/toolgroup_list_response.py">ToolgroupListResponse</a></code>
- <code title="get /v1/toolgroups/{toolgroup_id}">client.toolgroups.<a href="./src/llama_stack_client/resources/toolgroups.py">get</a>(toolgroup_id) -> <a href="./src/llama_stack_client/types/tool_group.py">ToolGroup</a></code>
- <code title="post /v1/toolgroups">client.toolgroups.<a href="./src/llama_stack_client/resources/toolgroups.py">register</a>(\*\*<a href="src/llama_stack_client/types/toolgroup_register_params.py">params</a>) -> None</code>
- <code title="delete /v1/toolgroups/{toolgroup_id}">client.toolgroups.<a href="./src/llama_stack_client/resources/toolgroups.py">unregister</a>(toolgroup_id) -> None</code>
- <code title="get /v1/toolgroups">client.toolgroups.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/toolgroups.py">list</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/toolgroup_list_response.py">ToolgroupListResponse</a></code>
- <code title="get /v1/toolgroups/{toolgroup_id}">client.toolgroups.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/toolgroups.py">get</a>(toolgroup_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/tool_group.py">ToolGroup</a></code>
- <code title="post /v1/toolgroups">client.toolgroups.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/toolgroups.py">register</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/toolgroup_register_params.py">params</a>) -> None</code>
- <code title="delete /v1/toolgroups/{toolgroup_id}">client.toolgroups.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/toolgroups.py">unregister</a>(toolgroup_id) -> None</code>
## Tools
@ -57,8 +57,8 @@ from llama_stack_client.types import ListToolsResponse, Tool, ToolListResponse
Methods:
- <code title="get /v1/tools">client.tools.<a href="./src/llama_stack_client/resources/tools.py">list</a>(\*\*<a href="src/llama_stack_client/types/tool_list_params.py">params</a>) -> <a href="./src/llama_stack_client/types/tool_list_response.py">ToolListResponse</a></code>
- <code title="get /v1/tools/{tool_name}">client.tools.<a href="./src/llama_stack_client/resources/tools.py">get</a>(tool_name) -> <a href="./src/llama_stack_client/types/tool.py">Tool</a></code>
- <code title="get /v1/tools">client.tools.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/tools.py">list</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/tool_list_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/tool_list_response.py">ToolListResponse</a></code>
- <code title="get /v1/tools/{tool_name}">client.tools.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/tools.py">get</a>(tool_name) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/tool.py">Tool</a></code>
## ToolRuntime
@ -70,15 +70,15 @@ from llama_stack_client.types import ToolDef, ToolInvocationResult
Methods:
- <code title="post /v1/tool-runtime/invoke">client.tool_runtime.<a href="./src/llama_stack_client/resources/tool_runtime/tool_runtime.py">invoke_tool</a>(\*\*<a href="src/llama_stack_client/types/tool_runtime_invoke_tool_params.py">params</a>) -> <a href="./src/llama_stack_client/types/tool_invocation_result.py">ToolInvocationResult</a></code>
- <code title="get /v1/tool-runtime/list-tools">client.tool_runtime.<a href="./src/llama_stack_client/resources/tool_runtime/tool_runtime.py">list_tools</a>(\*\*<a href="src/llama_stack_client/types/tool_runtime_list_tools_params.py">params</a>) -> <a href="./src/llama_stack_client/types/tool_def.py">JSONLDecoder[ToolDef]</a></code>
- <code title="post /v1/tool-runtime/invoke">client.tool_runtime.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/tool_runtime/tool_runtime.py">invoke_tool</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/tool_runtime_invoke_tool_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/tool_invocation_result.py">ToolInvocationResult</a></code>
- <code title="get /v1/tool-runtime/list-tools">client.tool_runtime.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/tool_runtime/tool_runtime.py">list_tools</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/tool_runtime_list_tools_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/tool_def.py">JSONLDecoder[ToolDef]</a></code>
### RagTool
Methods:
- <code title="post /v1/tool-runtime/rag-tool/insert">client.tool_runtime.rag_tool.<a href="./src/llama_stack_client/resources/tool_runtime/rag_tool.py">insert</a>(\*\*<a href="src/llama_stack_client/types/tool_runtime/rag_tool_insert_params.py">params</a>) -> None</code>
- <code title="post /v1/tool-runtime/rag-tool/query">client.tool_runtime.rag_tool.<a href="./src/llama_stack_client/resources/tool_runtime/rag_tool.py">query</a>(\*\*<a href="src/llama_stack_client/types/tool_runtime/rag_tool_query_params.py">params</a>) -> <a href="./src/llama_stack_client/types/shared/query_result.py">QueryResult</a></code>
- <code title="post /v1/tool-runtime/rag-tool/insert">client.tool_runtime.rag_tool.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/tool_runtime/rag_tool.py">insert</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/tool_runtime/rag_tool_insert_params.py">params</a>) -> None</code>
- <code title="post /v1/tool-runtime/rag-tool/query">client.tool_runtime.rag_tool.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/tool_runtime/rag_tool.py">query</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/tool_runtime/rag_tool_query_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/shared/query_result.py">QueryResult</a></code>
## Agents
@ -97,8 +97,8 @@ from llama_stack_client.types import (
Methods:
- <code title="post /v1/agents">client.agents.<a href="./src/llama_stack_client/resources/agents/agents.py">create</a>(\*\*<a href="src/llama_stack_client/types/agent_create_params.py">params</a>) -> <a href="./src/llama_stack_client/types/agent_create_response.py">AgentCreateResponse</a></code>
- <code title="delete /v1/agents/{agent_id}">client.agents.<a href="./src/llama_stack_client/resources/agents/agents.py">delete</a>(agent_id) -> None</code>
- <code title="post /v1/agents">client.agents.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/agents.py">create</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agent_create_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agent_create_response.py">AgentCreateResponse</a></code>
- <code title="delete /v1/agents/{agent_id}">client.agents.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/agents.py">delete</a>(agent_id) -> None</code>
### Session
@ -110,9 +110,9 @@ from llama_stack_client.types.agents import Session, SessionCreateResponse
Methods:
- <code title="post /v1/agents/{agent_id}/session">client.agents.session.<a href="./src/llama_stack_client/resources/agents/session.py">create</a>(agent_id, \*\*<a href="src/llama_stack_client/types/agents/session_create_params.py">params</a>) -> <a href="./src/llama_stack_client/types/agents/session_create_response.py">SessionCreateResponse</a></code>
- <code title="get /v1/agents/{agent_id}/session/{session_id}">client.agents.session.<a href="./src/llama_stack_client/resources/agents/session.py">retrieve</a>(session_id, \*, agent_id, \*\*<a href="src/llama_stack_client/types/agents/session_retrieve_params.py">params</a>) -> <a href="./src/llama_stack_client/types/agents/session.py">Session</a></code>
- <code title="delete /v1/agents/{agent_id}/session/{session_id}">client.agents.session.<a href="./src/llama_stack_client/resources/agents/session.py">delete</a>(session_id, \*, agent_id) -> None</code>
- <code title="post /v1/agents/{agent_id}/session">client.agents.session.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/session.py">create</a>(agent_id, \*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/session_create_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/session_create_response.py">SessionCreateResponse</a></code>
- <code title="get /v1/agents/{agent_id}/session/{session_id}">client.agents.session.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/session.py">retrieve</a>(session_id, \*, agent_id, \*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/session_retrieve_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/session.py">Session</a></code>
- <code title="delete /v1/agents/{agent_id}/session/{session_id}">client.agents.session.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/session.py">delete</a>(session_id, \*, agent_id) -> None</code>
### Steps
@ -124,7 +124,7 @@ from llama_stack_client.types.agents import StepRetrieveResponse
Methods:
- <code title="get /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}">client.agents.steps.<a href="./src/llama_stack_client/resources/agents/steps.py">retrieve</a>(step_id, \*, agent_id, session_id, turn_id) -> <a href="./src/llama_stack_client/types/agents/step_retrieve_response.py">StepRetrieveResponse</a></code>
- <code title="get /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}">client.agents.steps.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/steps.py">retrieve</a>(step_id, \*, agent_id, session_id, turn_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/step_retrieve_response.py">StepRetrieveResponse</a></code>
### Turn
@ -136,8 +136,8 @@ from llama_stack_client.types.agents import Turn, TurnCreateResponse
Methods:
- <code title="post /v1/agents/{agent_id}/session/{session_id}/turn">client.agents.turn.<a href="./src/llama_stack_client/resources/agents/turn.py">create</a>(session_id, \*, agent_id, \*\*<a href="src/llama_stack_client/types/agents/turn_create_params.py">params</a>) -> <a href="./src/llama_stack_client/types/agents/turn_create_response.py">TurnCreateResponse</a></code>
- <code title="get /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}">client.agents.turn.<a href="./src/llama_stack_client/resources/agents/turn.py">retrieve</a>(turn_id, \*, agent_id, session_id) -> <a href="./src/llama_stack_client/types/agents/turn.py">Turn</a></code>
- <code title="post /v1/agents/{agent_id}/session/{session_id}/turn">client.agents.turn.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/turn.py">create</a>(session_id, \*, agent_id, \*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/turn_create_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/turn_create_response.py">TurnCreateResponse</a></code>
- <code title="get /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}">client.agents.turn.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/turn.py">retrieve</a>(turn_id, \*, agent_id, session_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/turn.py">Turn</a></code>
## BatchInference
@ -149,8 +149,8 @@ from llama_stack_client.types import BatchInferenceChatCompletionResponse
Methods:
- <code title="post /v1/batch-inference/chat-completion">client.batch_inference.<a href="./src/llama_stack_client/resources/batch_inference.py">chat_completion</a>(\*\*<a href="src/llama_stack_client/types/batch_inference_chat_completion_params.py">params</a>) -> <a href="./src/llama_stack_client/types/batch_inference_chat_completion_response.py">BatchInferenceChatCompletionResponse</a></code>
- <code title="post /v1/batch-inference/completion">client.batch_inference.<a href="./src/llama_stack_client/resources/batch_inference.py">completion</a>(\*\*<a href="src/llama_stack_client/types/batch_inference_completion_params.py">params</a>) -> <a href="./src/llama_stack_client/types/shared/batch_completion.py">BatchCompletion</a></code>
- <code title="post /v1/batch-inference/chat-completion">client.batch_inference.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/batch_inference.py">chat_completion</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/batch_inference_chat_completion_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/batch_inference_chat_completion_response.py">BatchInferenceChatCompletionResponse</a></code>
- <code title="post /v1/batch-inference/completion">client.batch_inference.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/batch_inference.py">completion</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/batch_inference_completion_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/shared/batch_completion.py">BatchCompletion</a></code>
## Datasets
@ -166,10 +166,10 @@ from llama_stack_client.types import (
Methods:
- <code title="get /v1/datasets/{dataset_id}">client.datasets.<a href="./src/llama_stack_client/resources/datasets.py">retrieve</a>(dataset_id) -> <a href="./src/llama_stack_client/types/dataset_retrieve_response.py">Optional[DatasetRetrieveResponse]</a></code>
- <code title="get /v1/datasets">client.datasets.<a href="./src/llama_stack_client/resources/datasets.py">list</a>() -> <a href="./src/llama_stack_client/types/dataset_list_response.py">DatasetListResponse</a></code>
- <code title="post /v1/datasets">client.datasets.<a href="./src/llama_stack_client/resources/datasets.py">register</a>(\*\*<a href="src/llama_stack_client/types/dataset_register_params.py">params</a>) -> None</code>
- <code title="delete /v1/datasets/{dataset_id}">client.datasets.<a href="./src/llama_stack_client/resources/datasets.py">unregister</a>(dataset_id) -> None</code>
- <code title="get /v1/datasets/{dataset_id}">client.datasets.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/datasets.py">retrieve</a>(dataset_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/dataset_retrieve_response.py">Optional[DatasetRetrieveResponse]</a></code>
- <code title="get /v1/datasets">client.datasets.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/datasets.py">list</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/dataset_list_response.py">DatasetListResponse</a></code>
- <code title="post /v1/datasets">client.datasets.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/datasets.py">register</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/dataset_register_params.py">params</a>) -> None</code>
- <code title="delete /v1/datasets/{dataset_id}">client.datasets.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/datasets.py">unregister</a>(dataset_id) -> None</code>
## Eval
@ -181,8 +181,8 @@ from llama_stack_client.types import EvaluateResponse, Job
Methods:
- <code title="post /v1/eval/tasks/{benchmark_id}/evaluations">client.eval.<a href="./src/llama_stack_client/resources/eval/eval.py">evaluate_rows</a>(benchmark_id, \*\*<a href="src/llama_stack_client/types/eval_evaluate_rows_params.py">params</a>) -> <a href="./src/llama_stack_client/types/evaluate_response.py">EvaluateResponse</a></code>
- <code title="post /v1/eval/tasks/{benchmark_id}/jobs">client.eval.<a href="./src/llama_stack_client/resources/eval/eval.py">run_eval</a>(benchmark_id, \*\*<a href="src/llama_stack_client/types/eval_run_eval_params.py">params</a>) -> <a href="./src/llama_stack_client/types/job.py">Job</a></code>
- <code title="post /v1/eval/tasks/{benchmark_id}/evaluations">client.eval.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/eval/eval.py">evaluate_rows</a>(benchmark_id, \*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/eval_evaluate_rows_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/evaluate_response.py">EvaluateResponse</a></code>
- <code title="post /v1/eval/tasks/{benchmark_id}/jobs">client.eval.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/eval/eval.py">run_eval</a>(benchmark_id, \*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/eval_run_eval_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/job.py">Job</a></code>
### Jobs
@ -194,9 +194,9 @@ from llama_stack_client.types.eval import JobStatusResponse
Methods:
- <code title="get /v1/eval/tasks/{benchmark_id}/jobs/{job_id}/result">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">retrieve</a>(job_id, \*, benchmark_id) -> <a href="./src/llama_stack_client/types/evaluate_response.py">EvaluateResponse</a></code>
- <code title="delete /v1/eval/tasks/{benchmark_id}/jobs/{job_id}">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">cancel</a>(job_id, \*, benchmark_id) -> None</code>
- <code title="get /v1/eval/tasks/{benchmark_id}/jobs/{job_id}">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">status</a>(job_id, \*, benchmark_id) -> Optional[JobStatusResponse]</code>
- <code title="get /v1/eval/tasks/{benchmark_id}/jobs/{job_id}/result">client.eval.jobs.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/eval/jobs.py">retrieve</a>(job_id, \*, benchmark_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/evaluate_response.py">EvaluateResponse</a></code>
- <code title="delete /v1/eval/tasks/{benchmark_id}/jobs/{job_id}">client.eval.jobs.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/eval/jobs.py">cancel</a>(job_id, \*, benchmark_id) -> None</code>
- <code title="get /v1/eval/tasks/{benchmark_id}/jobs/{job_id}">client.eval.jobs.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/eval/jobs.py">status</a>(job_id, \*, benchmark_id) -> Optional[JobStatusResponse]</code>
## Inspect
@ -208,8 +208,8 @@ from llama_stack_client.types import HealthInfo, ProviderInfo, RouteInfo, Versio
Methods:
- <code title="get /v1/health">client.inspect.<a href="./src/llama_stack_client/resources/inspect.py">health</a>() -> <a href="./src/llama_stack_client/types/health_info.py">HealthInfo</a></code>
- <code title="get /v1/version">client.inspect.<a href="./src/llama_stack_client/resources/inspect.py">version</a>() -> <a href="./src/llama_stack_client/types/version_info.py">VersionInfo</a></code>
- <code title="get /v1/health">client.inspect.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/inspect.py">health</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/health_info.py">HealthInfo</a></code>
- <code title="get /v1/version">client.inspect.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/inspect.py">version</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/version_info.py">VersionInfo</a></code>
## Inference
@ -227,9 +227,9 @@ from llama_stack_client.types import (
Methods:
- <code title="post /v1/inference/chat-completion">client.inference.<a href="./src/llama_stack_client/resources/inference.py">chat_completion</a>(\*\*<a href="src/llama_stack_client/types/inference_chat_completion_params.py">params</a>) -> <a href="./src/llama_stack_client/types/inference_chat_completion_response.py">InferenceChatCompletionResponse</a></code>
- <code title="post /v1/inference/completion">client.inference.<a href="./src/llama_stack_client/resources/inference.py">completion</a>(\*\*<a href="src/llama_stack_client/types/inference_completion_params.py">params</a>) -> <a href="./src/llama_stack_client/types/inference_completion_response.py">InferenceCompletionResponse</a></code>
- <code title="post /v1/inference/embeddings">client.inference.<a href="./src/llama_stack_client/resources/inference.py">embeddings</a>(\*\*<a href="src/llama_stack_client/types/inference_embeddings_params.py">params</a>) -> <a href="./src/llama_stack_client/types/embeddings_response.py">EmbeddingsResponse</a></code>
- <code title="post /v1/inference/chat-completion">client.inference.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/inference.py">chat_completion</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/inference_chat_completion_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/inference_chat_completion_response.py">InferenceChatCompletionResponse</a></code>
- <code title="post /v1/inference/completion">client.inference.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/inference.py">completion</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/inference_completion_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/inference_completion_response.py">InferenceCompletionResponse</a></code>
- <code title="post /v1/inference/embeddings">client.inference.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/inference.py">embeddings</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/inference_embeddings_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/embeddings_response.py">EmbeddingsResponse</a></code>
## VectorIo
@ -241,8 +241,8 @@ from llama_stack_client.types import QueryChunksResponse
Methods:
- <code title="post /v1/vector-io/insert">client.vector_io.<a href="./src/llama_stack_client/resources/vector_io.py">insert</a>(\*\*<a href="src/llama_stack_client/types/vector_io_insert_params.py">params</a>) -> None</code>
- <code title="post /v1/vector-io/query">client.vector_io.<a href="./src/llama_stack_client/resources/vector_io.py">query</a>(\*\*<a href="src/llama_stack_client/types/vector_io_query_params.py">params</a>) -> <a href="./src/llama_stack_client/types/query_chunks_response.py">QueryChunksResponse</a></code>
- <code title="post /v1/vector-io/insert">client.vector_io.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/vector_io.py">insert</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/vector_io_insert_params.py">params</a>) -> None</code>
- <code title="post /v1/vector-io/query">client.vector_io.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/vector_io.py">query</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/vector_io_query_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/query_chunks_response.py">QueryChunksResponse</a></code>
## VectorDBs
@ -259,10 +259,10 @@ from llama_stack_client.types import (
Methods:
- <code title="get /v1/vector-dbs/{vector_db_id}">client.vector_dbs.<a href="./src/llama_stack_client/resources/vector_dbs.py">retrieve</a>(vector_db_id) -> <a href="./src/llama_stack_client/types/vector_db_retrieve_response.py">Optional[VectorDBRetrieveResponse]</a></code>
- <code title="get /v1/vector-dbs">client.vector_dbs.<a href="./src/llama_stack_client/resources/vector_dbs.py">list</a>() -> <a href="./src/llama_stack_client/types/vector_db_list_response.py">VectorDBListResponse</a></code>
- <code title="post /v1/vector-dbs">client.vector_dbs.<a href="./src/llama_stack_client/resources/vector_dbs.py">register</a>(\*\*<a href="src/llama_stack_client/types/vector_db_register_params.py">params</a>) -> <a href="./src/llama_stack_client/types/vector_db_register_response.py">VectorDBRegisterResponse</a></code>
- <code title="delete /v1/vector-dbs/{vector_db_id}">client.vector_dbs.<a href="./src/llama_stack_client/resources/vector_dbs.py">unregister</a>(vector_db_id) -> None</code>
- <code title="get /v1/vector-dbs/{vector_db_id}">client.vector_dbs.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/vector_dbs.py">retrieve</a>(vector_db_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/vector_db_retrieve_response.py">Optional[VectorDBRetrieveResponse]</a></code>
- <code title="get /v1/vector-dbs">client.vector_dbs.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/vector_dbs.py">list</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/vector_db_list_response.py">VectorDBListResponse</a></code>
- <code title="post /v1/vector-dbs">client.vector_dbs.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/vector_dbs.py">register</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/vector_db_register_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/vector_db_register_response.py">VectorDBRegisterResponse</a></code>
- <code title="delete /v1/vector-dbs/{vector_db_id}">client.vector_dbs.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/vector_dbs.py">unregister</a>(vector_db_id) -> None</code>
## Models
@ -274,10 +274,10 @@ from llama_stack_client.types import ListModelsResponse, Model, ModelListRespons
Methods:
- <code title="get /v1/models/{model_id}">client.models.<a href="./src/llama_stack_client/resources/models.py">retrieve</a>(model_id) -> <a href="./src/llama_stack_client/types/model.py">Optional[Model]</a></code>
- <code title="get /v1/models">client.models.<a href="./src/llama_stack_client/resources/models.py">list</a>() -> <a href="./src/llama_stack_client/types/model_list_response.py">ModelListResponse</a></code>
- <code title="post /v1/models">client.models.<a href="./src/llama_stack_client/resources/models.py">register</a>(\*\*<a href="src/llama_stack_client/types/model_register_params.py">params</a>) -> <a href="./src/llama_stack_client/types/model.py">Model</a></code>
- <code title="delete /v1/models/{model_id}">client.models.<a href="./src/llama_stack_client/resources/models.py">unregister</a>(model_id) -> None</code>
- <code title="get /v1/models/{model_id}">client.models.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/models.py">retrieve</a>(model_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/model.py">Optional[Model]</a></code>
- <code title="get /v1/models">client.models.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/models.py">list</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/model_list_response.py">ModelListResponse</a></code>
- <code title="post /v1/models">client.models.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/models.py">register</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/model_register_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/model.py">Model</a></code>
- <code title="delete /v1/models/{model_id}">client.models.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/models.py">unregister</a>(model_id) -> None</code>
## PostTraining
@ -289,8 +289,8 @@ from llama_stack_client.types import ListPostTrainingJobsResponse, PostTrainingJ
Methods:
- <code title="post /v1/post-training/preference-optimize">client.post_training.<a href="./src/llama_stack_client/resources/post_training/post_training.py">preference_optimize</a>(\*\*<a href="src/llama_stack_client/types/post_training_preference_optimize_params.py">params</a>) -> <a href="./src/llama_stack_client/types/post_training_job.py">PostTrainingJob</a></code>
- <code title="post /v1/post-training/supervised-fine-tune">client.post_training.<a href="./src/llama_stack_client/resources/post_training/post_training.py">supervised_fine_tune</a>(\*\*<a href="src/llama_stack_client/types/post_training_supervised_fine_tune_params.py">params</a>) -> <a href="./src/llama_stack_client/types/post_training_job.py">PostTrainingJob</a></code>
- <code title="post /v1/post-training/preference-optimize">client.post_training.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/post_training/post_training.py">preference_optimize</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/post_training_preference_optimize_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/post_training_job.py">PostTrainingJob</a></code>
- <code title="post /v1/post-training/supervised-fine-tune">client.post_training.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/post_training/post_training.py">supervised_fine_tune</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/post_training_supervised_fine_tune_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/post_training_job.py">PostTrainingJob</a></code>
### Job
@ -306,10 +306,10 @@ from llama_stack_client.types.post_training import (
Methods:
- <code title="get /v1/post-training/jobs">client.post_training.job.<a href="./src/llama_stack_client/resources/post_training/job.py">list</a>() -> <a href="./src/llama_stack_client/types/post_training/job_list_response.py">JobListResponse</a></code>
- <code title="get /v1/post-training/job/artifacts">client.post_training.job.<a href="./src/llama_stack_client/resources/post_training/job.py">artifacts</a>(\*\*<a href="src/llama_stack_client/types/post_training/job_artifacts_params.py">params</a>) -> <a href="./src/llama_stack_client/types/post_training/job_artifacts_response.py">Optional[JobArtifactsResponse]</a></code>
- <code title="post /v1/post-training/job/cancel">client.post_training.job.<a href="./src/llama_stack_client/resources/post_training/job.py">cancel</a>(\*\*<a href="src/llama_stack_client/types/post_training/job_cancel_params.py">params</a>) -> None</code>
- <code title="get /v1/post-training/job/status">client.post_training.job.<a href="./src/llama_stack_client/resources/post_training/job.py">status</a>(\*\*<a href="src/llama_stack_client/types/post_training/job_status_params.py">params</a>) -> <a href="./src/llama_stack_client/types/post_training/job_status_response.py">Optional[JobStatusResponse]</a></code>
- <code title="get /v1/post-training/jobs">client.post_training.job.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/post_training/job.py">list</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/post_training/job_list_response.py">JobListResponse</a></code>
- <code title="get /v1/post-training/job/artifacts">client.post_training.job.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/post_training/job.py">artifacts</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/post_training/job_artifacts_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/post_training/job_artifacts_response.py">Optional[JobArtifactsResponse]</a></code>
- <code title="post /v1/post-training/job/cancel">client.post_training.job.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/post_training/job.py">cancel</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/post_training/job_cancel_params.py">params</a>) -> None</code>
- <code title="get /v1/post-training/job/status">client.post_training.job.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/post_training/job.py">status</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/post_training/job_status_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/post_training/job_status_response.py">Optional[JobStatusResponse]</a></code>
## Providers
@ -321,7 +321,7 @@ from llama_stack_client.types import ListProvidersResponse, ProviderListResponse
Methods:
- <code title="get /v1/inspect/providers">client.providers.<a href="./src/llama_stack_client/resources/providers.py">list</a>() -> <a href="./src/llama_stack_client/types/provider_list_response.py">ProviderListResponse</a></code>
- <code title="get /v1/inspect/providers">client.providers.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/providers.py">list</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/provider_list_response.py">ProviderListResponse</a></code>
## Routes
@ -333,7 +333,7 @@ from llama_stack_client.types import ListRoutesResponse, RouteListResponse
Methods:
- <code title="get /v1/inspect/routes">client.routes.<a href="./src/llama_stack_client/resources/routes.py">list</a>() -> <a href="./src/llama_stack_client/types/route_list_response.py">RouteListResponse</a></code>
- <code title="get /v1/inspect/routes">client.routes.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/routes.py">list</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/route_list_response.py">RouteListResponse</a></code>
## Safety
@ -345,7 +345,7 @@ from llama_stack_client.types import RunShieldResponse
Methods:
- <code title="post /v1/safety/run-shield">client.safety.<a href="./src/llama_stack_client/resources/safety.py">run_shield</a>(\*\*<a href="src/llama_stack_client/types/safety_run_shield_params.py">params</a>) -> <a href="./src/llama_stack_client/types/run_shield_response.py">RunShieldResponse</a></code>
- <code title="post /v1/safety/run-shield">client.safety.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/safety.py">run_shield</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/safety_run_shield_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/run_shield_response.py">RunShieldResponse</a></code>
## Shields
@ -357,9 +357,9 @@ from llama_stack_client.types import ListShieldsResponse, Shield, ShieldListResp
Methods:
- <code title="get /v1/shields/{identifier}">client.shields.<a href="./src/llama_stack_client/resources/shields.py">retrieve</a>(identifier) -> <a href="./src/llama_stack_client/types/shield.py">Optional[Shield]</a></code>
- <code title="get /v1/shields">client.shields.<a href="./src/llama_stack_client/resources/shields.py">list</a>() -> <a href="./src/llama_stack_client/types/shield_list_response.py">ShieldListResponse</a></code>
- <code title="post /v1/shields">client.shields.<a href="./src/llama_stack_client/resources/shields.py">register</a>(\*\*<a href="src/llama_stack_client/types/shield_register_params.py">params</a>) -> <a href="./src/llama_stack_client/types/shield.py">Shield</a></code>
- <code title="get /v1/shields/{identifier}">client.shields.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/shields.py">retrieve</a>(identifier) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/shield.py">Optional[Shield]</a></code>
- <code title="get /v1/shields">client.shields.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/shields.py">list</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/shield_list_response.py">ShieldListResponse</a></code>
- <code title="post /v1/shields">client.shields.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/shields.py">register</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/shield_register_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/shield.py">Shield</a></code>
## SyntheticDataGeneration
@ -371,7 +371,7 @@ from llama_stack_client.types import SyntheticDataGenerationResponse
Methods:
- <code title="post /v1/synthetic-data-generation/generate">client.synthetic_data_generation.<a href="./src/llama_stack_client/resources/synthetic_data_generation.py">generate</a>(\*\*<a href="src/llama_stack_client/types/synthetic_data_generation_generate_params.py">params</a>) -> <a href="./src/llama_stack_client/types/synthetic_data_generation_response.py">SyntheticDataGenerationResponse</a></code>
- <code title="post /v1/synthetic-data-generation/generate">client.synthetic_data_generation.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/synthetic_data_generation.py">generate</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/synthetic_data_generation_generate_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/synthetic_data_generation_response.py">SyntheticDataGenerationResponse</a></code>
## Telemetry
@ -391,13 +391,13 @@ from llama_stack_client.types import (
Methods:
- <code title="get /v1/telemetry/traces/{trace_id}/spans/{span_id}">client.telemetry.<a href="./src/llama_stack_client/resources/telemetry.py">get_span</a>(span_id, \*, trace_id) -> <a href="./src/llama_stack_client/types/telemetry_get_span_response.py">TelemetryGetSpanResponse</a></code>
- <code title="get /v1/telemetry/spans/{span_id}/tree">client.telemetry.<a href="./src/llama_stack_client/resources/telemetry.py">get_span_tree</a>(span_id, \*\*<a href="src/llama_stack_client/types/telemetry_get_span_tree_params.py">params</a>) -> <a href="./src/llama_stack_client/types/telemetry_get_span_tree_response.py">TelemetryGetSpanTreeResponse</a></code>
- <code title="get /v1/telemetry/traces/{trace_id}">client.telemetry.<a href="./src/llama_stack_client/resources/telemetry.py">get_trace</a>(trace_id) -> <a href="./src/llama_stack_client/types/trace.py">Trace</a></code>
- <code title="post /v1/telemetry/events">client.telemetry.<a href="./src/llama_stack_client/resources/telemetry.py">log_event</a>(\*\*<a href="src/llama_stack_client/types/telemetry_log_event_params.py">params</a>) -> None</code>
- <code title="get /v1/telemetry/spans">client.telemetry.<a href="./src/llama_stack_client/resources/telemetry.py">query_spans</a>(\*\*<a href="src/llama_stack_client/types/telemetry_query_spans_params.py">params</a>) -> <a href="./src/llama_stack_client/types/telemetry_query_spans_response.py">TelemetryQuerySpansResponse</a></code>
- <code title="get /v1/telemetry/traces">client.telemetry.<a href="./src/llama_stack_client/resources/telemetry.py">query_traces</a>(\*\*<a href="src/llama_stack_client/types/telemetry_query_traces_params.py">params</a>) -> <a href="./src/llama_stack_client/types/telemetry_query_traces_response.py">TelemetryQueryTracesResponse</a></code>
- <code title="post /v1/telemetry/spans/export">client.telemetry.<a href="./src/llama_stack_client/resources/telemetry.py">save_spans_to_dataset</a>(\*\*<a href="src/llama_stack_client/types/telemetry_save_spans_to_dataset_params.py">params</a>) -> None</code>
- <code title="get /v1/telemetry/traces/{trace_id}/spans/{span_id}">client.telemetry.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/telemetry.py">get_span</a>(span_id, \*, trace_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/telemetry_get_span_response.py">TelemetryGetSpanResponse</a></code>
- <code title="get /v1/telemetry/spans/{span_id}/tree">client.telemetry.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/telemetry.py">get_span_tree</a>(span_id, \*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/telemetry_get_span_tree_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/telemetry_get_span_tree_response.py">TelemetryGetSpanTreeResponse</a></code>
- <code title="get /v1/telemetry/traces/{trace_id}">client.telemetry.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/telemetry.py">get_trace</a>(trace_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/trace.py">Trace</a></code>
- <code title="post /v1/telemetry/events">client.telemetry.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/telemetry.py">log_event</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/telemetry_log_event_params.py">params</a>) -> None</code>
- <code title="get /v1/telemetry/spans">client.telemetry.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/telemetry.py">query_spans</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/telemetry_query_spans_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/telemetry_query_spans_response.py">TelemetryQuerySpansResponse</a></code>
- <code title="get /v1/telemetry/traces">client.telemetry.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/telemetry.py">query_traces</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/telemetry_query_traces_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/telemetry_query_traces_response.py">TelemetryQueryTracesResponse</a></code>
- <code title="post /v1/telemetry/spans/export">client.telemetry.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/telemetry.py">save_spans_to_dataset</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/telemetry_save_spans_to_dataset_params.py">params</a>) -> None</code>
## Datasetio
@ -409,8 +409,8 @@ from llama_stack_client.types import PaginatedRowsResult
Methods:
- <code title="post /v1/datasetio/rows">client.datasetio.<a href="./src/llama_stack_client/resources/datasetio.py">append_rows</a>(\*\*<a href="src/llama_stack_client/types/datasetio_append_rows_params.py">params</a>) -> None</code>
- <code title="get /v1/datasetio/rows">client.datasetio.<a href="./src/llama_stack_client/resources/datasetio.py">get_rows_paginated</a>(\*\*<a href="src/llama_stack_client/types/datasetio_get_rows_paginated_params.py">params</a>) -> <a href="./src/llama_stack_client/types/paginated_rows_result.py">PaginatedRowsResult</a></code>
- <code title="post /v1/datasetio/rows">client.datasetio.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/datasetio.py">append_rows</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/datasetio_append_rows_params.py">params</a>) -> None</code>
- <code title="get /v1/datasetio/rows">client.datasetio.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/datasetio.py">get_rows_paginated</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/datasetio_get_rows_paginated_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/paginated_rows_result.py">PaginatedRowsResult</a></code>
## Scoring
@ -422,8 +422,8 @@ from llama_stack_client.types import ScoringScoreResponse, ScoringScoreBatchResp
Methods:
- <code title="post /v1/scoring/score">client.scoring.<a href="./src/llama_stack_client/resources/scoring.py">score</a>(\*\*<a href="src/llama_stack_client/types/scoring_score_params.py">params</a>) -> <a href="./src/llama_stack_client/types/scoring_score_response.py">ScoringScoreResponse</a></code>
- <code title="post /v1/scoring/score-batch">client.scoring.<a href="./src/llama_stack_client/resources/scoring.py">score_batch</a>(\*\*<a href="src/llama_stack_client/types/scoring_score_batch_params.py">params</a>) -> <a href="./src/llama_stack_client/types/scoring_score_batch_response.py">ScoringScoreBatchResponse</a></code>
- <code title="post /v1/scoring/score">client.scoring.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/scoring.py">score</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/scoring_score_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/scoring_score_response.py">ScoringScoreResponse</a></code>
- <code title="post /v1/scoring/score-batch">client.scoring.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/scoring.py">score_batch</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/scoring_score_batch_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/scoring_score_batch_response.py">ScoringScoreBatchResponse</a></code>
## ScoringFunctions
@ -439,9 +439,9 @@ from llama_stack_client.types import (
Methods:
- <code title="get /v1/scoring-functions/{scoring_fn_id}">client.scoring_functions.<a href="./src/llama_stack_client/resources/scoring_functions.py">retrieve</a>(scoring_fn_id) -> <a href="./src/llama_stack_client/types/scoring_fn.py">Optional[ScoringFn]</a></code>
- <code title="get /v1/scoring-functions">client.scoring_functions.<a href="./src/llama_stack_client/resources/scoring_functions.py">list</a>() -> <a href="./src/llama_stack_client/types/scoring_function_list_response.py">ScoringFunctionListResponse</a></code>
- <code title="post /v1/scoring-functions">client.scoring_functions.<a href="./src/llama_stack_client/resources/scoring_functions.py">register</a>(\*\*<a href="src/llama_stack_client/types/scoring_function_register_params.py">params</a>) -> None</code>
- <code title="get /v1/scoring-functions/{scoring_fn_id}">client.scoring_functions.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/scoring_functions.py">retrieve</a>(scoring_fn_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/scoring_fn.py">Optional[ScoringFn]</a></code>
- <code title="get /v1/scoring-functions">client.scoring_functions.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/scoring_functions.py">list</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/scoring_function_list_response.py">ScoringFunctionListResponse</a></code>
- <code title="post /v1/scoring-functions">client.scoring_functions.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/scoring_functions.py">register</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/scoring_function_register_params.py">params</a>) -> None</code>
## Benchmarks
@ -457,6 +457,6 @@ from llama_stack_client.types import (
Methods:
- <code title="get /v1/eval-tasks/{benchmark_id}">client.benchmarks.<a href="./src/llama_stack_client/resources/benchmarks.py">retrieve</a>(benchmark_id) -> <a href="./src/llama_stack_client/types/benchmark.py">Optional[Benchmark]</a></code>
- <code title="get /v1/eval-tasks">client.benchmarks.<a href="./src/llama_stack_client/resources/benchmarks.py">list</a>() -> <a href="./src/llama_stack_client/types/benchmark_list_response.py">BenchmarkListResponse</a></code>
- <code title="post /v1/eval-tasks">client.benchmarks.<a href="./src/llama_stack_client/resources/benchmarks.py">register</a>(\*\*<a href="src/llama_stack_client/types/benchmark_register_params.py">params</a>) -> None</code>
- <code title="get /v1/eval-tasks/{benchmark_id}">client.benchmarks.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/benchmarks.py">retrieve</a>(benchmark_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/benchmark.py">Optional[Benchmark]</a></code>
- <code title="get /v1/eval-tasks">client.benchmarks.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/benchmarks.py">list</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/benchmark_list_response.py">BenchmarkListResponse</a></code>
- <code title="post /v1/eval-tasks">client.benchmarks.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/benchmarks.py">register</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/benchmark_register_params.py">params</a>) -> None</code>


@ -64,23 +64,3 @@ class Benchmarks(Protocol):
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None: ...
@webmethod(route="/eval-tasks", method="GET")
async def DEPRECATED_list_eval_tasks(self) -> ListBenchmarksResponse: ...
@webmethod(route="/eval-tasks/{eval_task_id}", method="GET")
async def DEPRECATED_get_eval_task(
self,
eval_task_id: str,
) -> Optional[Benchmark]: ...
@webmethod(route="/eval-tasks", method="POST")
async def DEPRECATED_register_eval_task(
self,
eval_task_id: str,
dataset_id: str,
scoring_functions: List[str],
provider_benchmark_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None: ...


@ -39,7 +39,6 @@ EvalCandidate = register_schema(
@json_schema_type
class BenchmarkConfig(BaseModel):
type: Literal["benchmark"] = "benchmark"
eval_candidate: EvalCandidate
scoring_params: Dict[str, ScoringFnParams] = Field(
description="Map between scoring function id and parameters for each scoring function you want to run",
@ -84,28 +83,3 @@ class Eval(Protocol):
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: ...
@webmethod(route="/eval/tasks/{task_id}/jobs", method="POST")
async def DEPRECATED_run_eval(
self,
task_id: str,
task_config: BenchmarkConfig,
) -> Job: ...
@webmethod(route="/eval/tasks/{task_id}/evaluations", method="POST")
async def DEPRECATED_evaluate_rows(
self,
task_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
task_config: BenchmarkConfig,
) -> EvaluateResponse: ...
@webmethod(route="/eval/tasks/{task_id}/jobs/{job_id}", method="GET")
async def DEPRECATED_job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...
@webmethod(route="/eval/tasks/{task_id}/jobs/{job_id}", method="DELETE")
async def DEPRECATED_job_cancel(self, task_id: str, job_id: str) -> None: ...
@webmethod(route="/eval/tasks/{task_id}/jobs/{job_id}/result", method="GET")
async def DEPRECATED_job_result(self, task_id: str, job_id: str) -> EvaluateResponse: ...


@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .files import * # noqa: F401 F403


@ -0,0 +1,174 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class FileUploadResponse(BaseModel):
"""
Response after initiating a file upload session.
:param id: ID of the upload session
:param url: Upload URL for the file or file parts
:param offset: Upload content offset
:param size: Upload content size
"""
id: str
url: str
offset: int
size: int
@json_schema_type
class BucketResponse(BaseModel):
name: str
@json_schema_type
class ListBucketResponse(BaseModel):
"""
Response representing a list of bucket entries.
:param data: List of BucketResponse entries
"""
data: List[BucketResponse]
@json_schema_type
class FileResponse(BaseModel):
"""
Response representing a file entry.
:param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-)
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
:param mime_type: MIME type of the file
:param url: Upload URL for the file contents
:param bytes: Size of the file in bytes
:param created_at: Timestamp of when the file was created
"""
bucket: str
key: str
mime_type: str
url: str
bytes: int
created_at: int
@json_schema_type
class ListFileResponse(BaseModel):
"""
Response representing a list of file entries.
:param data: List of FileResponse entries
"""
data: List[FileResponse]
@runtime_checkable
@trace_protocol
class Files(Protocol):
@webmethod(route="/files", method="POST")
async def create_upload_session(
self,
bucket: str,
key: str,
mime_type: str,
size: int,
) -> FileUploadResponse:
"""
Create a new upload session for a file identified by a bucket and key.
:param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-)
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
:param mime_type: MIME type of the file
:param size: File size in bytes
"""
...
@webmethod(route="/files/session:{upload_id}", method="POST", raw_bytes_request_body=True)
async def upload_content_to_session(
self,
upload_id: str,
) -> Optional[FileResponse]:
"""
Upload file content to an existing upload session.
On the server, the request body will contain the raw bytes that are uploaded.
:param upload_id: ID of the upload session
"""
...
@webmethod(route="/files/session:{upload_id}", method="GET")
async def get_upload_session_info(
self,
upload_id: str,
) -> Optional[FileUploadResponse]:
"""
Returns information about an existing upload session.
:param upload_id: ID of the upload session
"""
...
@webmethod(route="/files", method="GET")
async def list_all_buckets(
self,
) -> ListBucketResponse:
"""
List all buckets.
"""
...
@webmethod(route="/files/{bucket}", method="GET")
async def list_files_in_bucket(
self,
bucket: str,
) -> ListFileResponse:
"""
List all files in a bucket.
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
"""
...
@webmethod(route="/files/{bucket}/{key:path}", method="GET")
async def get_file(
self,
bucket: str,
key: str,
) -> FileResponse:
"""
Get info for a file identified by a bucket and key.
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
"""
...
@webmethod(route="/files/{bucket}/{key:path}", method="DELETE")
async def delete_file(
self,
bucket: str,
key: str,
) -> FileResponse:
"""
Delete a file identified by a bucket and key.
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
"""
...
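Read together, the routes above describe a three-step upload flow: create a session for a `(bucket, key)` pair, push the raw bytes to the session URL, then fetch the stored file's metadata. A minimal client-side sketch of that flow, assuming a hypothetical `client` whose `files` resource mirrors this protocol and using plain `requests` for the raw-byte upload (both are illustrative, not part of this change):

```python
# Hypothetical sketch of the upload flow defined by the Files protocol above.
# `client.files.*` and the raw-byte POST via `requests` are assumptions.
import requests

def upload_file(client, bucket: str, key: str, data: bytes, mime_type: str):
    # 1. Initiate an upload session for the (bucket, key) pair.
    session = client.files.create_upload_session(
        bucket=bucket, key=key, mime_type=mime_type, size=len(data)
    )
    # 2. Send the raw bytes to the session URL; the server reads the request
    #    body directly (raw_bytes_request_body=True on the route).
    requests.post(session.url, data=data, headers={"Content-Type": mime_type})
    # 3. Read back the stored file's metadata (bucket, key, size, timestamp).
    return client.files.get_file(bucket=bucket, key=key)
```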


@ -216,7 +216,7 @@ class Telemetry(Protocol):
@webmethod(route="/telemetry/events", method="POST")
async def log_event(self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400) -> None: ...
@webmethod(route="/telemetry/traces", method="GET")
@webmethod(route="/telemetry/traces", method="POST")
async def query_traces(
self,
attribute_filters: Optional[List[QueryCondition]] = None,
@ -231,7 +231,7 @@ class Telemetry(Protocol):
@webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
async def get_span(self, trace_id: str, span_id: str) -> Span: ...
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="GET")
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST")
async def get_span_tree(
self,
span_id: str,
@ -239,7 +239,7 @@ class Telemetry(Protocol):
max_depth: Optional[int] = None,
) -> QuerySpanTreeResponse: ...
@webmethod(route="/telemetry/spans", method="GET")
@webmethod(route="/telemetry/spans", method="POST")
async def query_spans(
self,
attribute_filters: List[QueryCondition],


@ -5,12 +5,44 @@
# the root directory of this source tree.
import argparse
import os
import time
from pathlib import Path
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.models.llama.sku_list import all_registered_models
def _get_model_size(model_dir):
return sum(f.stat().st_size for f in Path(model_dir).rglob("*") if f.is_file())
def _run_model_list_downloaded_cmd() -> None:
headers = ["Model", "Size", "Modified Time"]
rows = []
for model in os.listdir(DEFAULT_CHECKPOINT_DIR):
abs_path = os.path.join(DEFAULT_CHECKPOINT_DIR, model)
space_usage = _get_model_size(abs_path)
model_size = f"{space_usage / (1024**3):.2f} GB"
modified_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(os.path.getmtime(abs_path)))
rows.append(
[
model,
model_size,
modified_time,
]
)
print_table(
rows,
headers,
separate_rows=True,
)
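# Illustration, not part of this file: _get_model_size sums bytes recursively,
# and the GiB formatting above renders 2_684_354_560 bytes as "2.50 GB":
#   f"{2_684_354_560 / (1024**3):.2f} GB" == "2.50 GB"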
class ModelList(Subcommand):
"""List available llama models"""
@ -31,10 +63,18 @@ class ModelList(Subcommand):
action="store_true",
help="Show all models (not just defaults)",
)
self.parser.add_argument(
"--downloaded",
action="store_true",
help="List the downloaded models",
)
def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
from .safety_models import prompt_guard_model_sku
if args.downloaded:
return _run_model_list_downloaded_cmd()
headers = [
"Model Descriptor(ID)",
"Hugging Face Repo",


@ -15,6 +15,7 @@ TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
UV_SYSTEM_PYTHON=${UV_SYSTEM_PYTHON:-}
if [ -n "$LLAMA_STACK_DIR" ]; then
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
@ -73,11 +74,16 @@ run() {
local env_name="$1"
local pip_dependencies="$2"
local special_pip_deps="$3"
if [ -n "$UV_SYSTEM_PYTHON" ]; then
echo "Installing dependencies in system Python environment"
else
echo "Using virtual environment $env_name"
uv venv "$env_name"
# shellcheck source=/dev/null
source "$env_name/bin/activate"
fi
echo "Using virtual environment $env_name"
uv venv "$env_name"
# shellcheck source=/dev/null
source "$env_name/bin/activate"
if [ -n "$TEST_PYPI_VERSION" ]; then
# these packages are damaged in test-pypi, so install them first
uv pip install fastapi libcst


@ -10,12 +10,12 @@ cleanup() {
set +x
echo "Cleaning up..."
conda deactivate
conda env remove --name $envname -y
conda env remove --name "$envname" -y
}
handle_int() {
if [ -n $ENVNAME ]; then
cleanup $ENVNAME
if [ -n "$ENVNAME" ]; then
cleanup "$ENVNAME"
fi
exit 1
}
@ -23,8 +23,8 @@ handle_int() {
handle_exit() {
if [ $? -ne 0 ]; then
echo -e "\033[1;31mABORTING.\033[0m"
if [ -n $ENVNAME ]; then
cleanup $ENVNAME
if [ -n "$ENVNAME" ]; then
cleanup "$ENVNAME"
fi
fi
}
@ -33,10 +33,14 @@ setup_cleanup_handlers() {
trap handle_int INT
trap handle_exit EXIT
__conda_setup="$('conda' 'shell.bash' 'hook' 2>/dev/null)"
eval "$__conda_setup"
conda deactivate
if is_command_available conda; then
__conda_setup="$('conda' 'shell.bash' 'hook' 2>/dev/null)"
eval "$__conda_setup"
conda deactivate
else
echo "conda is not available"
exit 1
fi
}
# check if a command is present


@ -411,48 +411,6 @@ class EvalRouter(Eval):
job_id,
)
async def DEPRECATED_run_eval(
self,
task_id: str,
task_config: BenchmarkConfig,
) -> Job:
return await self.run_eval(benchmark_id=task_id, task_config=task_config)
async def DEPRECATED_evaluate_rows(
self,
task_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
task_config: BenchmarkConfig,
) -> EvaluateResponse:
return await self.evaluate_rows(
benchmark_id=task_id,
input_rows=input_rows,
scoring_functions=scoring_functions,
task_config=task_config,
)
async def DEPRECATED_job_status(
self,
task_id: str,
job_id: str,
) -> Optional[JobStatus]:
return await self.job_status(benchmark_id=task_id, job_id=job_id)
async def DEPRECATED_job_cancel(
self,
task_id: str,
job_id: str,
) -> None:
return await self.job_cancel(benchmark_id=task_id, job_id=job_id)
async def DEPRECATED_job_result(
self,
task_id: str,
job_id: str,
) -> EvaluateResponse:
return await self.job_result(benchmark_id=task_id, job_id=job_id)
class ToolRuntimeRouter(ToolRuntime):
class RagToolImpl(RAGToolRuntime):


@ -468,35 +468,6 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
)
await self.register_object(benchmark)
async def DEPRECATED_list_eval_tasks(self) -> ListBenchmarksResponse:
logger.warning("DEPRECATED: Use /eval/benchmarks instead")
return await self.list_benchmarks()
async def DEPRECATED_get_eval_task(
self,
eval_task_id: str,
) -> Optional[Benchmark]:
logger.warning("DEPRECATED: Use /eval/benchmarks instead")
return await self.get_benchmark(eval_task_id)
async def DEPRECATED_register_eval_task(
self,
eval_task_id: str,
dataset_id: str,
scoring_functions: List[str],
provider_benchmark_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
logger.warning("DEPRECATED: Use /eval/benchmarks instead")
return await self.register_benchmark(
benchmark_id=eval_task_id,
dataset_id=dataset_id,
scoring_functions=scoring_functions,
metadata=metadata,
provider_benchmark_id=provider_benchmark_id,
)
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:


@ -481,6 +481,8 @@ def main():
def extract_path_params(route: str) -> List[str]:
segments = route.split("/")
params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]
# to handle path params like {param:path}
params = [param.split(":")[0] for param in params]
return params
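With the `{key:path}` routes added elsewhere in this change, the new line strips the `:path` type suffix so only the parameter name is returned; for example (calls shown purely for illustration):

```python
# The ":path" suffix is dropped; plain parameters and literal segments are unaffected.
assert extract_path_params("/files/{bucket}/{key:path}") == ["bucket", "key"]
assert extract_path_params("/v1/inspect/routes") == []
```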


@ -19,6 +19,7 @@ from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
@ -63,6 +64,7 @@ class LlamaStack(
ToolGroups,
ToolRuntime,
RAGToolRuntime,
Files,
):
pass


@ -34,6 +34,7 @@ def _run_with_pty_unix(command):
original_sigint = signal.getsignal(signal.SIGINT)
ctrl_c_pressed = False
process = None
def sigint_handler(signum, frame):
nonlocal ctrl_c_pressed
@ -98,7 +99,7 @@ def _run_with_pty_unix(command):
signal.signal(signal.SIGINT, original_sigint)
os.close(master)
if process.poll() is None:
if process and process.poll() is None:
process.terminate()
process.wait()


@ -234,45 +234,3 @@ class MetaReferenceEvalImpl(
raise ValueError(f"Job is not completed, Status: {status.value}")
return self.jobs[job_id]
async def DEPRECATED_run_eval(
self,
task_id: str,
task_config: BenchmarkConfig,
) -> Job:
return await self.run_eval(benchmark_id=task_id, task_config=task_config)
async def DEPRECATED_evaluate_rows(
self,
task_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
task_config: BenchmarkConfig,
) -> EvaluateResponse:
return await self.evaluate_rows(
benchmark_id=task_id,
input_rows=input_rows,
scoring_functions=scoring_functions,
task_config=task_config,
)
async def DEPRECATED_job_status(
self,
task_id: str,
job_id: str,
) -> Optional[JobStatus]:
return await self.job_status(benchmark_id=task_id, job_id=job_id)
async def DEPRECATED_job_cancel(
self,
task_id: str,
job_id: str,
) -> None:
return await self.job_cancel(benchmark_id=task_id, job_id=job_id)
async def DEPRECATED_job_result(
self,
task_id: str,
job_id: str,
) -> EvaluateResponse:
return await self.job_result(benchmark_id=task_id, job_id=job_id)


@ -46,7 +46,7 @@ from llama_stack.providers.utils.inference.embedding_mixin import (
)
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_model_alias,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
augment_content_with_response_format_prompt,
@ -116,7 +116,7 @@ class MetaReferenceInferenceImpl(
self.model_registry_helper = ModelRegistryHelper(
[
build_model_alias(
build_hf_repo_model_entry(
llama_model.descriptor(),
llama_model.core_model_id.value,
)


@ -9,7 +9,6 @@ import os
import uuid
from typing import AsyncGenerator, List, Optional
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
@ -62,7 +61,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
def __init__(self, config: VLLMConfig):
self.config = config
self.engine = None
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self):
log.info("Initializing vLLM inference provider.")
@ -177,7 +175,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
log.info("Sampling params: %s", sampling_params)
request_id = _random_uuid()
prompt = await chat_completion_request_to_prompt(request, self.config.model, self.formatter)
prompt = await chat_completion_request_to_prompt(request, self.config.model)
vllm_sampling_params = self._sampling_params(request.sampling_params)
results_generator = self.engine.generate(prompt, vllm_sampling_params, request_id)
if stream:
@ -201,11 +199,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(response, self.formatter, request)
return process_chat_completion_response(response, request)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
) -> AsyncGenerator:
tokenizer = Tokenizer.get_instance()
async def _generate_and_convert_to_openai_compat():
cur = []
async for chunk in results_generator:
@ -216,7 +216,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
output = chunk.outputs[-1]
new_tokens = output.token_ids[len(cur) :]
text = self.formatter.tokenizer.decode(new_tokens)
text = tokenizer.decode(new_tokens)
cur.extend(new_tokens)
choice = OpenAICompatCompletionChoice(
finish_reason=output.finish_reason,
@ -227,7 +227,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:


@ -40,8 +40,7 @@ class TorchtunePostTrainingImpl:
self.datasets_api = datasets
# TODO: assume sync job, will need jobs API for async scheduling
self.jobs_status = {}
self.jobs_list = []
self.jobs = {}
self.checkpoints_dict = {}
async def supervised_fine_tune(
@ -54,9 +53,8 @@ class TorchtunePostTrainingImpl:
checkpoint_dir: Optional[str],
algorithm_config: Optional[AlgorithmConfig],
) -> PostTrainingJob:
for job in self.jobs_list:
if job_uuid == job.job_uuid:
raise ValueError(f"Job {job_uuid} already exists")
if job_uuid in self.jobs:
raise ValueError(f"Job {job_uuid} already exists")
post_training_job = PostTrainingJob(job_uuid=job_uuid)
@ -65,8 +63,8 @@ class TorchtunePostTrainingImpl:
status=JobStatus.scheduled,
scheduled_at=datetime.now(),
)
self.jobs[job_uuid] = job_status_response
self.jobs_list.append(post_training_job)
if isinstance(algorithm_config, LoraFinetuningConfig):
try:
recipe = LoraFinetuningSingleDevice(
@ -100,8 +98,6 @@ class TorchtunePostTrainingImpl:
else:
raise NotImplementedError()
self.jobs_status[job_uuid] = job_status_response
return post_training_job
async def preference_optimize(
@ -115,13 +111,11 @@ class TorchtunePostTrainingImpl:
) -> PostTrainingJob: ...
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(data=self.jobs_list)
return ListPostTrainingJobsResponse(data=[PostTrainingJob(job_uuid=uuid_) for uuid_ in self.jobs])
@webmethod(route="/post-training/job/status")
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
if job_uuid in self.jobs_status:
return self.jobs_status[job_uuid]
return None
return self.jobs.get(job_uuid, None)
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None:


@ -64,8 +64,6 @@ from torchtune.models.llama3._tokenizer import Llama3Tokenizer
class LoraFinetuningSingleDevice:
# This recipe only supports GPU training
# This recipe doesn't include several training-efficiency settings from the original torchtune repo, including
# - compile
# - activation offloading
@ -93,7 +91,7 @@ class LoraFinetuningSingleDevice:
if not isinstance(algorithm_config, LoraFinetuningConfig):
raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning")
self.algorithm_config = algorithm_config
self._device = torchtune_utils.get_device(device="cuda")
self._device = torchtune_utils.get_device()
self._dtype = training.get_dtype(training_config.dtype, device=self._device)
self.model_id = model
@ -231,6 +229,13 @@ class LoraFinetuningSingleDevice:
# Used to ignore labels for loss computation
self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device)
def _log_memory_stats(self):
# torchtune raises: "Logging memory stats is not supported on CPU devices"; do nothing
if self._device.type == "cpu":
return
memory_stats = training.get_memory_stats(device=self._device)
training.log_memory_stats(memory_stats)
async def _setup_model(
self,
enable_activation_checkpointing: bool,
@ -293,8 +298,7 @@ class LoraFinetuningSingleDevice:
# activation offloading
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading)
memory_stats = training.get_memory_stats(device=self._device)
training.log_memory_stats(memory_stats)
self._log_memory_stats()
return model
@ -506,8 +510,7 @@ class LoraFinetuningSingleDevice:
"tokens_per_second_per_gpu": num_tokens / time_per_step,
}
memory_stats = training.get_memory_stats(device=self._device)
log_dict.update(memory_stats)
self._log_memory_stats()
if self._clip_grad_norm is not None:
log_dict.update({"grad_norm": grad_norm})


@ -5,7 +5,11 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams, ScoringFn
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
LLMAsJudgeScoringFnParams,
ScoringFn,
)
GRADER_TEMPLATE = """
Your job is to look at a question, a gold target, and a predicted answer, and then assign a grade of either ["CORRECT", "INCORRECT", "NOT_ATTEMPTED"].
@ -87,5 +91,6 @@ llm_as_judge_405b_simpleqa = ScoringFn(
judge_model="meta-llama/Llama-3.1-405B-Instruct",
prompt_template=GRADER_TEMPLATE,
judge_score_regexes=[r"(A|B|C)"],
aggregation_functions=[AggregationFunctionType.categorical_count.value],
),
)


@ -9,10 +9,11 @@ from typing import Any, Dict
from llama_stack.providers.datatypes import Api
from .config import RagToolRuntimeConfig
from .memory import MemoryToolRuntimeImpl
async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[str, Any]):
from .memory import MemoryToolRuntimeImpl
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
await impl.initialize()
return impl


@ -4,9 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import hashlib
import logging
import sqlite3
import struct
import uuid
from typing import Any, Dict, List, Optional
import numpy as np
@ -52,14 +54,14 @@ class SQLiteVecIndex(EmbeddingIndex):
# Create the table to store chunk metadata.
cur.execute(f"""
CREATE TABLE IF NOT EXISTS {self.metadata_table} (
id INTEGER PRIMARY KEY,
id TEXT PRIMARY KEY,
chunk TEXT
);
""")
# Create the virtual table for embeddings.
cur.execute(f"""
CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
USING vec0(embedding FLOAT[{self.dimension}]);
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
""")
self.connection.commit()
@ -69,9 +71,9 @@ class SQLiteVecIndex(EmbeddingIndex):
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
self.connection.commit()
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray, batch_size: int = 500):
"""
Add new chunks along with their embeddings.
Add new chunks along with their embeddings using batch inserts.
For each chunk, we insert its JSON into the metadata table and insert its
embedding (serialized to raw bytes) into the virtual table, keyed by the same generated chunk ID.
If any insert fails, the transaction is rolled back to maintain consistency.
@ -80,21 +82,35 @@ class SQLiteVecIndex(EmbeddingIndex):
try:
# Start transaction
cur.execute("BEGIN TRANSACTION")
for chunk, emb in zip(chunks, embeddings, strict=False):
# Serialize and insert the chunk metadata.
chunk_json = chunk.model_dump_json()
cur.execute(f"INSERT INTO {self.metadata_table} (chunk) VALUES (?)", (chunk_json,))
row_id = cur.lastrowid
# Ensure the embedding is a list of floats.
emb_list = emb.tolist() if isinstance(emb, np.ndarray) else list(emb)
emb_blob = serialize_vector(emb_list)
cur.execute(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?)", (row_id, emb_blob))
# Commit transaction if all inserts succeed
for i in range(0, len(chunks), batch_size):
batch_chunks = chunks[i : i + batch_size]
batch_embeddings = embeddings[i : i + batch_size]
# Prepare metadata inserts
metadata_data = [
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
for chunk in batch_chunks
]
# Insert metadata (ON CONFLICT to avoid duplicates)
cur.executemany(
f"""
INSERT INTO {self.metadata_table} (id, chunk)
VALUES (?, ?)
ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
""",
metadata_data,
)
# Prepare embeddings inserts
embedding_data = [
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), serialize_vector(emb.tolist()))
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
]
# Insert embeddings in batch
cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
self.connection.commit()
except sqlite3.Error as e:
self.connection.rollback() # Rollback on failure
print(f"Error inserting into {self.vector_table} - error: {e}") # Log error (Consider using logging module)
logger.error(f"Error inserting into {self.vector_table}: {e}")
finally:
cur.close() # Ensure cursor is closed
@ -110,7 +126,7 @@ class SQLiteVecIndex(EmbeddingIndex):
query_sql = f"""
SELECT m.id, m.chunk, v.distance
FROM {self.vector_table} AS v
JOIN {self.metadata_table} AS m ON m.id = v.rowid
JOIN {self.metadata_table} AS m ON m.id = v.id
WHERE v.embedding MATCH ? AND k = ?
ORDER BY v.distance;
"""
@ -204,7 +220,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
if vector_db_id not in self.cache:
raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
# The VectorDBWithIndex helper is expected to compute embeddings via the inference_api
# and then call our indexs add_chunks.
# and then call our index's add_chunks.
await self.cache[vector_db_id].insert_chunks(chunks)
async def query_chunks(
@ -213,3 +229,9 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
if vector_db_id not in self.cache:
raise ValueError(f"Vector DB {vector_db_id} not found")
return await self.cache[vector_db_id].query_chunks(query, params)
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
"""Generate a unique chunk ID using a hash of document ID and chunk text."""
hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
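Because the ID is an MD5-derived UUID over the document ID and chunk text, re-inserting the same chunk always produces the same key, which is what lets the `ON CONFLICT(id) DO UPDATE` clause above deduplicate rows; for example (values are illustrative):

```python
# Same (document_id, text) pair -> same ID, so repeated inserts update
# the existing row instead of creating a duplicate.
a = generate_chunk_id("doc-1", "hello world")
b = generate_chunk_id("doc-1", "hello world")
c = generate_chunk_id("doc-2", "hello world")
assert a == b and a != c
```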


@ -215,4 +215,14 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="passthrough",
pip_packages=[],
module="llama_stack.providers.remote.inference.passthrough",
config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
),
),
]


@ -8,8 +8,6 @@ import json
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from botocore.client import BaseClient
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
@ -27,12 +25,10 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_model_alias,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
@ -47,29 +43,15 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
MODEL_ALIASES = [
build_model_alias(
"meta.llama3-1-8b-instruct-v1:0",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
"meta.llama3-1-70b-instruct-v1:0",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_alias(
"meta.llama3-1-405b-instruct-v1:0",
CoreModelId.llama3_1_405b_instruct.value,
),
]
from .models import MODEL_ENTRIES
class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ALIASES)
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
self._config = config
self._client = create_bedrock_client(config)
self.formatter = ChatFormat(Tokenizer.get_instance())
@property
def client(self) -> BaseClient:
@ -134,7 +116,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
)
response = OpenAICompatCompletionResponse(choices=[choice])
return process_chat_completion_response(response, self.formatter, request)
return process_chat_completion_response(response, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params_for_chat_completion(request)
@ -152,7 +134,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
yield OpenAICompatCompletionResponse(choices=[choice])
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
@ -166,7 +148,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
if sampling_params.repetition_penalty > 0:
options["repetition_penalty"] = sampling_params.repetition_penalty
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model), self.formatter)
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
return {
"modelId": bedrock_model,
"body": json.dumps(


@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta.llama3-1-8b-instruct-v1:0",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta.llama3-1-70b-instruct-v1:0",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta.llama3-1-405b-instruct-v1:0",
CoreModelId.llama3_1_405b_instruct.value,
),
]


@ -7,8 +7,6 @@
from typing import AsyncGenerator, List, Optional, Union
from cerebras.cloud.sdk import AsyncCerebras
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
@ -26,10 +24,9 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.datatypes import CoreModelId, TopKSamplingStrategy
from llama_stack.models.llama.datatypes import TopKSamplingStrategy
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_model_alias,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
@ -44,27 +41,16 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from .config import CerebrasImplConfig
model_aliases = [
build_model_alias(
"llama3.1-8b",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
"llama-3.3-70b",
CoreModelId.llama3_3_70b_instruct.value,
),
]
from .models import model_entries
class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: CerebrasImplConfig) -> None:
ModelRegistryHelper.__init__(
self,
model_aliases=model_aliases,
model_entries=model_entries,
)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
self.client = AsyncCerebras(
base_url=self.config.base_url,
@ -107,14 +93,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
r = await self.client.completions.create(**params)
return process_completion_response(r, self.formatter)
return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
stream = await self.client.completions.create(**params)
async for chunk in process_completion_stream_response(stream, self.formatter):
async for chunk in process_completion_stream_response(stream):
yield chunk
async def chat_completion(
@ -154,14 +140,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
r = await self.client.completions.create(**params)
return process_chat_completion_response(r, self.formatter, request)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
stream = await self.client.completions.create(**params)
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
@ -170,11 +156,9 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
prompt = ""
if isinstance(request, ChatCompletionRequest):
prompt = await chat_completion_request_to_prompt(
request, self.get_llama_model(request.model), self.formatter
)
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
elif isinstance(request, CompletionRequest):
prompt = await completion_request_to_prompt(request, self.formatter)
prompt = await completion_request_to_prompt(request)
else:
raise ValueError(f"Unknown request type {type(request)}")


@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
model_entries = [
build_hf_repo_model_entry(
"llama3.1-8b",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"llama-3.3-70b",
CoreModelId.llama3_3_70b_instruct.value,
),
]


@ -6,8 +6,6 @@
from typing import AsyncGenerator, List, Optional
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI
from llama_stack.apis.common.content_types import InterleavedContent
@ -27,7 +25,7 @@ from llama_stack.apis.inference import (
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_model_alias,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
@ -40,12 +38,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import DatabricksImplConfig
model_aliases = [
build_model_alias(
model_entries = [
build_hf_repo_model_entry(
"databricks-meta-llama-3-1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_alias(
build_hf_repo_model_entry(
"databricks-meta-llama-3-1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
@ -54,12 +52,8 @@ model_aliases = [
class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: DatabricksImplConfig) -> None:
ModelRegistryHelper.__init__(
self,
model_aliases=model_aliases,
)
ModelRegistryHelper.__init__(self, model_entries=model_entries)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
return
@ -112,7 +106,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(r, self.formatter, request)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
params = self._get_params(request)
@ -123,13 +117,13 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": request.model,
"prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model), self.formatter),
"prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model)),
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}


@ -7,8 +7,6 @@
from typing import AsyncGenerator, List, Optional, Union
from fireworks.client import Fireworks
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
@ -29,10 +27,8 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_model_alias,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
@ -51,56 +47,13 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from .config import FireworksImplConfig
MODEL_ALIASES = [
build_model_alias(
"accounts/fireworks/models/llama-v3p1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
"accounts/fireworks/models/llama-v3p1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_alias(
"accounts/fireworks/models/llama-v3p1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_model_alias(
"accounts/fireworks/models/llama-v3p2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_alias(
"accounts/fireworks/models/llama-v3p2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_model_alias(
"accounts/fireworks/models/llama-v3p2-11b-vision-instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_model_alias(
"accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_model_alias(
"accounts/fireworks/models/llama-v3p3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_model_alias(
"accounts/fireworks/models/llama-guard-3-8b",
CoreModelId.llama_guard_3_8b.value,
),
build_model_alias(
"accounts/fireworks/models/llama-guard-3-11b-vision",
CoreModelId.llama_guard_3_11b_vision.value,
),
]
from .models import MODEL_ENTRIES
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ALIASES)
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
pass
@ -149,7 +102,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self._get_client().completion.acreate(**params)
return process_completion_response(r, self.formatter)
return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
@ -161,7 +114,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
yield chunk
stream = _to_async_generator()
async for chunk in process_completion_stream_response(stream, self.formatter):
async for chunk in process_completion_stream_response(stream):
yield chunk
def _build_options(
@ -230,7 +183,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
r = await self._get_client().chat.completions.acreate(**params)
else:
r = await self._get_client().completion.acreate(**params)
return process_chat_completion_response(r, self.formatter, request)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
@ -244,7 +197,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
@ -258,11 +211,11 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(
request, self.get_llama_model(request.model), self.formatter
request, self.get_llama_model(request.model)
)
else:
assert not media_present, "Fireworks does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
input_dict["prompt"] = await completion_request_to_prompt(request)
# Fireworks always prepends with BOS
if "prompt" in input_dict:
@ -284,8 +237,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
model = await self.model_store.get_model(model_id)
kwargs = {}
if model.metadata.get("embedding_dimensions"):
kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
if model.metadata.get("embedding_dimension"):
kwargs["dimensions"] = model.metadata.get("embedding_dimension")
assert all(not content_has_media(content) for content in contents), (
"Fireworks does not support media for embeddings"
)

View file

@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
)
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p2-11b-vision-instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-guard-3-8b",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-guard-3-11b-vision",
CoreModelId.llama_guard_3_11b_vision.value,
),
ProviderModelEntry(
provider_model_id="nomic-ai/nomic-embed-text-v1.5",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
),
]

View file

@@ -31,8 +31,8 @@ from llama_stack.models.llama.sku_list import CoreModelId
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_model_alias,
build_model_alias_with_just_provider_model_id,
build_hf_repo_model_entry,
build_model_entry,
)
from .groq_utils import (
@@ -41,20 +41,20 @@ from .groq_utils import (
convert_chat_completion_response_stream,
)
_MODEL_ALIASES = [
build_model_alias(
_MODEL_ENTRIES = [
build_hf_repo_model_entry(
"llama3-8b-8192",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias_with_just_provider_model_id(
build_model_entry(
"llama-3.1-8b-instant",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
build_hf_repo_model_entry(
"llama3-70b-8192",
CoreModelId.llama3_70b_instruct.value,
),
build_model_alias(
build_hf_repo_model_entry(
"llama-3.3-70b-versatile",
CoreModelId.llama3_3_70b_instruct.value,
),
@@ -62,7 +62,7 @@ _MODEL_ALIASES = [
# Preview models aren't recommended for production use, but we include this one
# to pass the test fixture
# TODO(aidand): Replace this with a stable model once Groq supports it
build_model_alias(
build_hf_repo_model_entry(
"llama-3.2-3b-preview",
CoreModelId.llama3_2_3b_instruct.value,
),
@@ -73,7 +73,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
_config: GroqConfig
def __init__(self, config: GroqConfig):
ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
self._config = config
def completion(

View file

@@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
_MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta/llama3-8b-instruct",
CoreModelId.llama3_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama3-70b-instruct",
CoreModelId.llama3_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.2-11b-vision-instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
# TODO(mf): how do we handle Nemotron models?
# "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
]

View file

@@ -26,19 +26,14 @@ from llama_stack.apis.inference import (
ToolChoice,
ToolConfig,
)
from llama_stack.models.llama.datatypes import (
CoreModelId,
SamplingParams,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_model_alias,
)
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
from . import NVIDIAConfig
from .models import _MODEL_ENTRIES
from .openai_utils import (
convert_chat_completion_request,
convert_completion_request,
@@ -51,52 +46,11 @@ from .utils import _is_nvidia_hosted, check_health
logger = logging.getLogger(__name__)
_MODEL_ALIASES = [
build_model_alias(
"meta/llama3-8b-instruct",
CoreModelId.llama3_8b_instruct.value,
),
build_model_alias(
"meta/llama3-70b-instruct",
CoreModelId.llama3_70b_instruct.value,
),
build_model_alias(
"meta/llama-3.1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
"meta/llama-3.1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_alias(
"meta/llama-3.1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_model_alias(
"meta/llama-3.2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_alias(
"meta/llama-3.2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_model_alias(
"meta/llama-3.2-11b-vision-instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_model_alias(
"meta/llama-3.2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
# TODO(mf): how do we handle Nemotron models?
# "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
]
class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
def __init__(self, config: NVIDIAConfig) -> None:
# TODO(mf): filter by available models
ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")

View file

@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
build_model_entry,
)
model_entries = [
build_hf_repo_model_entry(
"llama3.1:8b-instruct-fp16",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_entry(
"llama3.1:8b",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.1:70b-instruct-fp16",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_entry(
"llama3.1:70b",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.1:405b-instruct-fp16",
CoreModelId.llama3_1_405b_instruct.value,
),
build_model_entry(
"llama3.1:405b",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.2:1b-instruct-fp16",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_entry(
"llama3.2:1b",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.2:3b-instruct-fp16",
CoreModelId.llama3_2_3b_instruct.value,
),
build_model_entry(
"llama3.2:3b",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.2-vision:11b-instruct-fp16",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_model_entry(
"llama3.2-vision:latest",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"llama3.2-vision:90b-instruct-fp16",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_model_entry(
"llama3.2-vision:90b",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"llama3.3:70b",
CoreModelId.llama3_3_70b_instruct.value,
),
# The Llama Guard models don't have their full fp16 versions
# so we are going to alias their default version to the canonical SKU
build_hf_repo_model_entry(
"llama-guard3:8b",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"llama-guard3:1b",
CoreModelId.llama_guard_3_1b.value,
),
ProviderModelEntry(
provider_model_id="all-minilm:latest",
aliases=["all-minilm"],
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,
"context_length": 512,
},
),
ProviderModelEntry(
provider_model_id="nomic-embed-text",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
),
]

View file

@@ -8,8 +8,6 @@ import logging
from typing import AsyncGenerator, List, Optional, Union
import httpx
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from ollama import AsyncClient
from llama_stack.apis.common.content_types import (
@@ -33,12 +31,9 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_model_alias,
build_model_alias_with_just_provider_model_id,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
@@ -58,87 +53,15 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
request_has_media,
)
log = logging.getLogger(__name__)
from .models import model_entries
model_aliases = [
build_model_alias(
"llama3.1:8b-instruct-fp16",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias_with_just_provider_model_id(
"llama3.1:8b",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
"llama3.1:70b-instruct-fp16",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_alias_with_just_provider_model_id(
"llama3.1:70b",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_alias(
"llama3.1:405b-instruct-fp16",
CoreModelId.llama3_1_405b_instruct.value,
),
build_model_alias_with_just_provider_model_id(
"llama3.1:405b",
CoreModelId.llama3_1_405b_instruct.value,
),
build_model_alias(
"llama3.2:1b-instruct-fp16",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_alias_with_just_provider_model_id(
"llama3.2:1b",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_alias(
"llama3.2:3b-instruct-fp16",
CoreModelId.llama3_2_3b_instruct.value,
),
build_model_alias_with_just_provider_model_id(
"llama3.2:3b",
CoreModelId.llama3_2_3b_instruct.value,
),
build_model_alias(
"llama3.2-vision:11b-instruct-fp16",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_model_alias_with_just_provider_model_id(
"llama3.2-vision:latest",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_model_alias(
"llama3.2-vision:90b-instruct-fp16",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_model_alias_with_just_provider_model_id(
"llama3.2-vision:90b",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_model_alias(
"llama3.3:70b",
CoreModelId.llama3_3_70b_instruct.value,
),
# The Llama Guard models don't have their full fp16 versions
# so we are going to alias their default version to the canonical SKU
build_model_alias(
"llama-guard3:8b",
CoreModelId.llama_guard_3_8b.value,
),
build_model_alias(
"llama-guard3:1b",
CoreModelId.llama_guard_3_1b.value,
),
]
log = logging.getLogger(__name__)
class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
def __init__(self, url: str) -> None:
self.register_helper = ModelRegistryHelper(model_aliases)
self.register_helper = ModelRegistryHelper(model_entries)
self.url = url
self.formatter = ChatFormat(Tokenizer.get_instance())
@property
def client(self) -> AsyncClient:
@@ -197,7 +120,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_completion_stream_response(stream, self.formatter):
async for chunk in process_completion_stream_response(stream):
yield chunk
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
@@ -212,7 +135,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
choices=[choice],
)
return process_completion_response(response, self.formatter)
return process_completion_response(response)
async def chat_completion(
self,
@@ -262,11 +185,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
input_dict["prompt"] = await chat_completion_request_to_prompt(
request,
self.register_helper.get_llama_model(request.model),
self.formatter,
)
else:
assert not media_present, "Ollama does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
input_dict["prompt"] = await completion_request_to_prompt(request)
input_dict["raw"] = True
if fmt := request.response_format:
@@ -304,7 +226,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(response, self.formatter, request)
return process_chat_completion_response(response, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
@@ -330,7 +252,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def embeddings(
@@ -352,22 +274,17 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return EmbeddingsResponse(embeddings=embeddings)
async def register_model(self, model: Model) -> Model:
async def check_model_availability(model_id: str):
response = await self.client.ps()
available_models = [m["model"] for m in response["models"]]
if model_id not in available_models:
raise ValueError(
f"Model '{model_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
)
if model.model_type == ModelType.embedding:
await check_model_availability(model.provider_resource_id)
return model
response = await self.client.list()
else:
response = await self.client.ps()
available_models = [m["model"] for m in response["models"]]
if model.provider_resource_id not in available_models:
raise ValueError(
f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
)
model = await self.register_helper.register_model(model)
await check_model_availability(model.provider_resource_id)
return model
return await self.register_helper.register_model(model)
async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:

View file

@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import PassthroughImplConfig
class PassthroughProviderDataValidator(BaseModel):
url: str
api_key: str
async def get_adapter_impl(config: PassthroughImplConfig, _deps):
from .passthrough import PassthroughInferenceAdapter
assert isinstance(config, PassthroughImplConfig), f"Unexpected config type: {type(config)}"
impl = PassthroughInferenceAdapter(config)
await impl.initialize()
return impl

View file

@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, SecretStr
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class PassthroughImplConfig(BaseModel):
url: str = Field(
default=None,
description="The URL for the passthrough endpoint",
)
api_key: Optional[SecretStr] = Field(
default=None,
description="API Key for the passthrouth endpoint",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"url": "${env.PASSTHROUGH_URL}",
"api_key": "${env.PASSTHROUGH_API_KEY}",
}

View file

@@ -0,0 +1,148 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, List, Optional
from llama_stack_client import LlamaStackClient
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
EmbeddingsResponse,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from .config import PassthroughImplConfig
class PassthroughInferenceAdapter(Inference):
def __init__(self, config: PassthroughImplConfig) -> None:
ModelRegistryHelper.__init__(self, [])
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def unregister_model(self, model_id: str) -> None:
pass
async def register_model(self, model: Model) -> Model:
return model
def _get_client(self) -> LlamaStackClient:
passthrough_url = None
passthrough_api_key = None
provider_data = None
if self.config.url is not None:
passthrough_url = self.config.url
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.passthrough_url:
raise ValueError(
'Pass the URL of the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_url": <your passthrough url> }'
)
passthrough_url = provider_data.passthrough_url
if self.config.api_key is not None:
passthrough_api_key = self.config.api_key.get_secret_value()
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.passthrough_api_key:
raise ValueError(
'Pass the API key for the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_api_key": <your api key> }'
)
passthrough_api_key = provider_data.passthrough_api_key
return LlamaStackClient(
base_url=passthrough_url,
api_key=passthrough_api_key,
provider_data=provider_data,
)
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
client = self._get_client()
model = await self.model_store.get_model(model_id)
params = {
"model_id": model.provider_resource_id,
"content": content,
"sampling_params": sampling_params,
"response_format": response_format,
"stream": stream,
"logprobs": logprobs,
}
params = {key: value for key, value in params.items() if value is not None}
# only pass through the params that are not None
return client.inference.completion(**params)
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
client = self._get_client()
model = await self.model_store.get_model(model_id)
params = {
"model_id": model.provider_resource_id,
"messages": messages,
"sampling_params": sampling_params,
"tools": tools,
"tool_choice": tool_choice,
"tool_prompt_format": tool_prompt_format,
"response_format": response_format,
"stream": stream,
"logprobs": logprobs,
}
params = {key: value for key, value in params.items() if value is not None}
# only pass through the params that are not None
return client.inference.chat_completion(**params)
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
client = self._get_client()
model = await self.model_store.get_model(model_id)
return client.inference.embeddings(
model_id=model.provider_resource_id,
contents=contents,
)

View file

@@ -5,8 +5,6 @@
# the root directory of this source tree.
from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
@ -45,7 +43,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: RunpodImplConfig) -> None:
ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
return
@@ -56,7 +53,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
async def completion(
self,
model: str,
content: InterleavedTextMedia,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@@ -97,7 +94,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(r, self.formatter, request)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
params = self._get_params(request)
@@ -108,13 +105,13 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": self.map_to_provider_model(request.model),
"prompt": chat_completion_request_to_prompt(request, self.formatter),
"prompt": chat_completion_request_to_prompt(request),
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
@@ -122,6 +119,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model: str,
contents: List[InterleavedTextMedia],
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@@ -7,7 +7,6 @@
from pydantic import BaseModel
from .config import SambaNovaImplConfig
from .sambanova import SambaNovaInferenceAdapter
class SambaNovaProviderDataValidator(BaseModel):
@@ -15,6 +14,8 @@ class SambaNovaProviderDataValidator(BaseModel):
async def get_adapter_impl(config: SambaNovaImplConfig, _deps):
from .sambanova import SambaNovaInferenceAdapter
assert isinstance(config, SambaNovaImplConfig), f"Unexpected config type: {type(config)}"
impl = SambaNovaInferenceAdapter(config)
await impl.initialize()

View file

@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"Meta-Llama-3.1-8B-Instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"Meta-Llama-3.1-70B-Instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"Meta-Llama-3.1-405B-Instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"Meta-Llama-3.2-1B-Instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"Meta-Llama-3.2-3B-Instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"Meta-Llama-3.3-70B-Instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"Llama-3.2-11B-Vision-Instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"Llama-3.2-90B-Vision-Instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
]

View file

@@ -7,8 +7,6 @@
import json
from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI
from llama_stack.apis.common.content_types import (
@@ -18,14 +16,12 @@ from llama_stack.apis.common.content_types import (
)
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.models.llama.datatypes import (
CoreModelId,
GreedySamplingStrategy,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_model_alias,
)
from llama_stack.providers.utils.inference.openai_compat import (
process_chat_completion_stream_response,
@@ -35,56 +31,13 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from .config import SambaNovaImplConfig
MODEL_ALIASES = [
build_model_alias(
"Meta-Llama-3.1-8B-Instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
"Meta-Llama-3.1-70B-Instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_alias(
"Meta-Llama-3.1-405B-Instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_model_alias(
"Meta-Llama-3.2-1B-Instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_alias(
"Meta-Llama-3.2-3B-Instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_model_alias(
"Meta-Llama-3.3-70B-Instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_model_alias(
"Llama-3.2-11B-Vision-Instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_model_alias(
"Llama-3.2-90B-Vision-Instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_model_alias(
"Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
]
from .models import MODEL_ENTRIES
class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: SambaNovaImplConfig) -> None:
ModelRegistryHelper.__init__(
self,
model_aliases=MODEL_ALIASES,
)
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
return
@@ -160,7 +113,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def embeddings(

View file

@@ -7,13 +7,14 @@
from typing import Union
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter
async def get_adapter_impl(
config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig],
_deps,
):
from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter
if isinstance(config, TGIImplConfig):
impl = TGIAdapter()
elif isinstance(config, InferenceAPIImplConfig):

View file

@@ -9,8 +9,6 @@ import logging
from typing import AsyncGenerator, List, Optional
from huggingface_hub import AsyncInferenceClient, HfApi
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
@@ -34,7 +32,7 @@ from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_model_alias,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
@@ -55,9 +53,9 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
log = logging.getLogger(__name__)
def build_model_aliases():
def build_hf_repo_model_entries():
return [
build_model_alias(
build_hf_repo_model_entry(
model.huggingface_repo,
model.descriptor(),
)
@@ -72,8 +70,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
model_id: str
def __init__(self) -> None:
self.formatter = ChatFormat(Tokenizer.get_instance())
self.register_helper = ModelRegistryHelper(build_model_aliases())
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.huggingface_repo_to_llama_model_id = {
model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
}
@@ -149,7 +146,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
return options
async def _get_params_for_completion(self, request: CompletionRequest) -> dict:
prompt, input_tokens = await completion_request_to_prompt_model_input_info(request, self.formatter)
prompt, input_tokens = await completion_request_to_prompt_model_input_info(request)
return dict(
prompt=prompt,
@@ -177,7 +174,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_completion_stream_response(stream, self.formatter):
async for chunk in process_completion_stream_response(stream):
yield chunk
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
@@ -193,7 +190,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
choices=[choice],
)
return process_completion_response(response, self.formatter)
return process_completion_response(response)
async def chat_completion(
self,
@@ -236,7 +233,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(response, self.formatter, request)
return process_chat_completion_response(response, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
@@ -252,12 +249,12 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: ChatCompletionRequest) -> dict:
prompt, input_tokens = await chat_completion_request_to_model_input_info(
request, self.register_helper.get_llama_model(request.model), self.formatter
request, self.register_helper.get_llama_model(request.model)
)
return dict(
prompt=prompt,

View file

@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
)
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-3B-Instruct-Turbo",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.3-70B-Instruct-Turbo",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-Guard-3-11B-Vision-Turbo",
CoreModelId.llama_guard_3_11b_vision.value,
),
ProviderModelEntry(
provider_model_id="togethercomputer/m2-bert-80M-8k-retrieval",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
),
ProviderModelEntry(
provider_model_id="togethercomputer/m2-bert-80M-32k-retrieval",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 768,
"context_length": 32768,
},
),
]

View file

@@ -6,8 +6,6 @@
from typing import AsyncGenerator, List, Optional, Union
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from together import Together
from llama_stack.apis.common.content_types import InterleavedContent
@@ -28,10 +26,8 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_model_alias,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
@@ -50,52 +46,13 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from .config import TogetherImplConfig
MODEL_ALIASES = [
build_model_alias(
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_alias(
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
CoreModelId.llama3_1_405b_instruct.value,
),
build_model_alias(
"meta-llama/Llama-3.2-3B-Instruct-Turbo",
CoreModelId.llama3_2_3b_instruct.value,
),
build_model_alias(
"meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_model_alias(
"meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_model_alias(
"meta-llama/Llama-3.3-70B-Instruct-Turbo",
CoreModelId.llama3_3_70b_instruct.value,
),
build_model_alias(
"meta-llama/Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
build_model_alias(
"meta-llama/Llama-Guard-3-11B-Vision-Turbo",
CoreModelId.llama_guard_3_11b_vision.value,
),
]
from .models import MODEL_ENTRIES
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ALIASES)
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
pass
@@ -142,7 +99,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = self._get_client().completions.create(**params)
return process_completion_response(r, self.formatter)
return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
@@ -154,7 +111,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
yield chunk
stream = _to_async_generator()
async for chunk in process_completion_stream_response(stream, self.formatter):
async for chunk in process_completion_stream_response(stream):
yield chunk
def _build_options(
@@ -220,7 +177,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
r = self._get_client().chat.completions.create(**params)
else:
r = self._get_client().completions.create(**params)
return process_chat_completion_response(r, self.formatter, request)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
@@ -235,7 +192,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
@@ -246,11 +203,11 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(
request, self.get_llama_model(request.model), self.formatter
request, self.get_llama_model(request.model)
)
else:
assert not media_present, "Together does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
input_dict["prompt"] = await completion_request_to_prompt(request)
return {
"model": request.model,

View file

@@ -8,8 +8,6 @@ import logging
from typing import AsyncGenerator, List, Optional, Union
from llama_models.datatypes import StopReason, ToolCall
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI
from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
@@ -40,7 +38,7 @@ from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_model_alias,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionResponse,
@@ -64,9 +62,9 @@ from .config import VLLMInferenceAdapterConfig
log = logging.getLogger(__name__)
def build_model_aliases():
def build_hf_repo_model_entries():
return [
build_model_alias(
build_hf_repo_model_entry(
model.huggingface_repo,
model.descriptor(),
)
@@ -150,19 +148,36 @@ async def _process_vllm_chat_completion_stream_response(
async for chunk in stream:
choice = chunk.choices[0]
if choice.finish_reason:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=ToolCallDelta(
tool_call=ToolCall(
call_id=tool_call_buf.call_id,
tool_name=tool_call_buf.tool_name,
arguments=json.loads(tool_call_buf.arguments),
args_str = tool_call_buf.arguments
args = None
try:
args = {} if not args_str else json.loads(args_str)
except Exception as e:
log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
if args is not None:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=ToolCallDelta(
tool_call=ToolCall(
call_id=tool_call_buf.call_id,
tool_name=tool_call_buf.tool_name,
arguments=args,
),
parse_status=ToolCallParseStatus.succeeded,
),
parse_status=ToolCallParseStatus.succeeded,
),
)
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=str(tool_call_buf),
parse_status=ToolCallParseStatus.failed,
),
)
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
@@ -189,9 +204,8 @@ async def _process_vllm_chat_completion_stream_response(
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
self.register_helper = ModelRegistryHelper(build_model_aliases())
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
self.client = None
async def initialize(self) -> None:
@@ -286,14 +300,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if len(request.tools) > 0:
res = _process_vllm_chat_completion_stream_response(stream)
else:
res = process_chat_completion_stream_response(stream, self.formatter, request)
res = process_chat_completion_stream_response(stream, request)
async for chunk in res:
yield chunk
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = self.client.completions.create(**params)
return process_completion_response(r, self.formatter)
return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
@@ -305,7 +319,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
yield chunk
stream = _to_async_generator()
async for chunk in process_completion_stream_response(stream, self.formatter):
async for chunk in process_completion_stream_response(stream):
yield chunk
async def register_model(self, model: Model) -> Model:
@@ -332,10 +346,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
else:
assert not request_has_media(request), "vLLM does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(
request,
self.formatter,
)
input_dict["prompt"] = await completion_request_to_prompt(request)
if fmt := request.response_format:
if fmt.type == ResponseFormatType.json_schema.value:
@@ -364,8 +375,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
kwargs = {}
assert model.model_type == ModelType.embedding
assert model.metadata.get("embedding_dimensions")
kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
assert model.metadata.get("embedding_dimension")
kwargs["dimensions"] = model.metadata.get("embedding_dimension")
assert all(not content_has_media(content) for content in contents), "VLLM does not support media for embeddings"
response = self.client.embeddings.create(
model=model.provider_resource_id,

View file

@@ -7,7 +7,6 @@
from pydantic import BaseModel
from .config import ModelContextProtocolConfig
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
class ModelContextProtocolToolProviderDataValidator(BaseModel):
@@ -15,6 +14,8 @@ class ModelContextProtocolToolProviderDataValidator(BaseModel):
async def get_adapter_impl(config: ModelContextProtocolConfig, _deps):
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
impl = ModelContextProtocolToolRuntimeImpl(config)
await impl.initialize()
return impl

View file

@@ -12,6 +12,20 @@ We use `pytest` and all of its dynamism to enable the features needed. Specifica
- We use `pytest_collection_modifyitems` to filter tests based on the test config (if specified); a rough sketch of this hook follows below.
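For illustration only (this is not the actual conftest.py logic, which is more involved), such a hook can deselect collected tests against an allowlist; the `TEST_ALLOWLIST` environment variable here is a hypothetical stand-in for the real test config:

```python
# Minimal sketch of pytest_collection_modifyitems filtering collected tests.
import os


def pytest_collection_modifyitems(config, items):
    allowlist = os.environ.get("TEST_ALLOWLIST")
    if not allowlist:
        return  # no config specified; keep every collected test
    allowed = set(allowlist.split(","))
    deselected = [item for item in items if item.name not in allowed]
    if deselected:
        # report the deselection so pytest's summary stays accurate
        config.hook.pytest_deselected(items=deselected)
        items[:] = [item for item in items if item.name in allowed]
```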
## Pre-requisites
Your development environment should have been configured as per the instructions in the
[CONTRIBUTING.md](../../../CONTRIBUTING.md) file. In particular, make sure to install the test extra
dependencies. Below is the full configuration:
```bash
$ cd llama-stack
$ uv sync --extra dev --extra test
$ uv pip install -e .
$ source .venv/bin/activate
```
## Common options
All tests support a `--providers` option which can be a string of the form `api1=provider_fixture1,api2=provider_fixture2`. So, when testing safety (which needs both the inference and safety APIs) you can use `--providers inference=together,safety=meta_reference` to use these fixtures in concert.
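For example, a safety run pairing those two fixtures might look like the following (the exact test path and the API key are placeholders):

```bash
pytest -s -v llama_stack/providers/tests/safety/test_safety.py \
  --providers inference=together,safety=meta_reference \
  --env TOGETHER_API_KEY=<...>
```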
@@ -50,6 +64,9 @@ pytest -s -v llama_stack/providers/tests/inference/test_text_inference.py \
--env FIREWORKS_API_KEY=<...>
```
> [!TIP]
> If you're using `uv`, you can isolate test executions by prefixing all commands with `uv run pytest...`.
## Agents
The Agents API composes three other APIs underneath:
@@ -87,3 +104,6 @@ pytest llama_stack/providers/tests/ --config=ci_test_config.yaml
Currently, we support test configs for the inference, agents, and memory API tests.
An example of the test config format can be found in ci_test_config.yaml.
## Test Data
We encourage providers to use our test data for internal development testing, to keep things easy and consistent with the tests we provide. Each test case may define its own data format; please refer to our test source code for details on how these fields are used in the tests.
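As a minimal sketch, a test can read one of these entries through the `TestCase` helper defined in `test_cases/test_case.py`; the id `chat_completion-01` maps to entry `"01"` of `chat_completion.json`:

```python
from llama_stack.providers.tests.test_cases.test_case import TestCase

tc = TestCase("chat_completion-01")  # loads entry "01" from chat_completion.json
messages = tc["messages"]            # the entry's "data.messages" list
expected = tc["expected"]            # the expected values the test asserts against
assert expected["first_name"] == "Michael"
```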

View file

@@ -20,7 +20,7 @@ from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.groq import GroqConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.providers.remote.inference.ollama import DEFAULT_OLLAMA_URL, OllamaImplConfig
from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
from llama_stack.providers.remote.inference.tgi import TGIImplConfig
from llama_stack.providers.remote.inference.together import TogetherImplConfig
@@ -83,17 +83,13 @@ def inference_cerebras() -> ProviderFixture:
@pytest.fixture(scope="session")
def inference_ollama(inference_model) -> ProviderFixture:
inference_model = [inference_model] if isinstance(inference_model, str) else inference_model
if inference_model and "Llama3.1-8B-Instruct" in inference_model:
pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
def inference_ollama() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="ollama",
provider_type="remote::ollama",
config=OllamaImplConfig(host="localhost", port=os.getenv("OLLAMA_PORT", 11434)).model_dump(),
config=OllamaImplConfig(url=os.getenv("OLLAMA_URL", DEFAULT_OLLAMA_URL)).model_dump(),
)
],
)

View file

@@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, patch
import pytest
# How to run this test:
@@ -15,6 +13,9 @@ import pytest
class TestModelRegistration:
def provider_supports_custom_names(self, provider) -> bool:
return "remote::ollama" not in provider.__provider_spec__.provider_type
@pytest.mark.asyncio
async def test_register_unsupported_model(self, inference_stack, inference_model):
inference_impl, models_impl = inference_stack
@@ -47,7 +48,12 @@ class TestModelRegistration:
)
@pytest.mark.asyncio
async def test_register_with_llama_model(self, inference_stack):
async def test_register_with_llama_model(self, inference_stack, inference_model):
inference_impl, models_impl = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if not self.provider_supports_custom_names(provider):
pytest.skip("Provider does not support custom model names")
_, models_impl = inference_stack
_ = await models_impl.register_model(
@@ -67,22 +73,6 @@ class TestModelRegistration:
provider_model_id="custom-model",
)
@pytest.mark.asyncio
async def test_initialize_model_during_registering(self, inference_stack):
_, models_impl = inference_stack
with patch(
"llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.load_model",
new_callable=AsyncMock,
) as mock_load_model:
_ = await models_impl.register_model(
model_id="Llama3.1-8B-Instruct",
metadata={
"llama_model": "meta-llama/Llama-3.1-8B-Instruct",
},
)
mock_load_model.assert_called_once()
@pytest.mark.asyncio
async def test_register_with_invalid_llama_model(self, inference_stack):
_, models_impl = inference_stack

View file

@@ -6,7 +6,7 @@
import pytest
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel, TypeAdapter, ValidationError
from llama_stack.apis.common.content_types import ToolCallParseStatus
from llama_stack.apis.inference import (
@@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
CompletionResponseStreamChunk,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
SystemMessage,
ToolChoice,
UserMessage,
@@ -30,6 +31,7 @@ from llama_stack.models.llama.datatypes import (
ToolParamDefinition,
ToolPromptFormat,
)
from llama_stack.providers.tests.test_cases.test_case import TestCase
from .utils import group_chunks
@@ -178,8 +180,9 @@ class TestInference:
else: # no token, no logprobs
assert not chunk.logprobs, "Logprobs should be empty"
@pytest.mark.parametrize("test_case", ["completion-01"])
@pytest.mark.asyncio(loop_scope="session")
async def test_completion_structured_output(self, inference_model, inference_stack):
async def test_completion_structured_output(self, inference_model, inference_stack, test_case):
inference_impl, _ = inference_stack
class Output(BaseModel):
@@ -187,7 +190,9 @@ class TestInference:
year_born: str
year_retired: str
user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
tc = TestCase(test_case)
user_input = tc["user_input"]
response = await inference_impl.completion(
model_id=inference_model,
content=user_input,
@@ -203,9 +208,10 @@ class TestInference:
assert isinstance(response.content, str)
answer = Output.model_validate_json(response.content)
assert answer.name == "Michael Jordan"
assert answer.year_born == "1963"
assert answer.year_retired == "2003"
expected = tc["expected"]
assert answer.name == expected["name"]
assert answer.year_born == expected["year_born"]
assert answer.year_retired == expected["year_retired"]
@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_non_streaming(
@@ -224,8 +230,9 @@ class TestInference:
assert isinstance(response.completion_message.content, str)
assert len(response.completion_message.content) > 0
@pytest.mark.parametrize("test_case", ["chat_completion-01"])
@pytest.mark.asyncio(loop_scope="session")
async def test_structured_output(self, inference_model, inference_stack, common_params):
async def test_structured_output(self, inference_model, inference_stack, common_params, test_case):
inference_impl, _ = inference_stack
class AnswerFormat(BaseModel):
@@ -234,20 +241,12 @@ class TestInference:
year_of_birth: int
num_seasons_in_nba: int
tc = TestCase(test_case)
messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]
response = await inference_impl.chat_completion(
model_id=inference_model,
messages=[
# we include context about Michael Jordan in the prompt so that the test is
# focused on the functionality of the model and not on the information embedded
# in the model. Llama 3.2 3B Instruct tends to think MJ played for 14 seasons.
SystemMessage(
content=(
"You are a helpful assistant.\n\n"
"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons."
)
),
UserMessage(content="Please give me information about Michael Jordan."),
],
messages=messages,
stream=False,
response_format=JsonSchemaResponseFormat(
json_schema=AnswerFormat.model_json_schema(),
@@ -260,10 +259,11 @@ class TestInference:
assert isinstance(response.completion_message.content, str)
answer = AnswerFormat.model_validate_json(response.completion_message.content)
assert answer.first_name == "Michael"
assert answer.last_name == "Jordan"
assert answer.year_of_birth == 1963
assert answer.num_seasons_in_nba == 15
expected = tc["expected"]
assert answer.first_name == expected["first_name"]
assert answer.last_name == expected["last_name"]
assert answer.year_of_birth == expected["year_of_birth"]
assert answer.num_seasons_in_nba == expected["num_seasons_in_nba"]
response = await inference_impl.chat_completion(
model_id=inference_model,

View file

@@ -83,7 +83,6 @@ def pytest_generate_tests(metafunc):
if "safety_shield" in metafunc.fixturenames:
shield_id = metafunc.config.getoption("--safety-shield")
if shield_id:
assert shield_id.startswith("meta-llama/")
params = [pytest.param(shield_id, id="")]
else:
params = SAFETY_SHIELD_PARAMS

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -0,0 +1,24 @@
{
"01": {
"name": "structured output",
"data": {
"notes": "We include context about Michael Jordan in the prompt so that the test is focused on the funtionality of the model and not on the information embedded in the model. Llama 3.2 3B Instruct tends to think MJ played for 14 seasons.",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant. Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons."
},
{
"role": "user",
"content": "Please give me information about Michael Jordan."
}
],
"expected": {
"first_name": "Michael",
"last_name": "Jordan",
"year_of_birth": 1963,
"num_seasons_in_nba": 15
}
}
}
}

View file

@@ -0,0 +1,13 @@
{
"01": {
"name": "structured output",
"data": {
"user_input": "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003.",
"expected": {
"name": "Michael Jordan",
"year_born": "1963",
"year_retired": "2003"
}
}
}
}

View file

@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import pathlib
class TestCase:
_apis = ["chat_completion", "completion"]
_jsonblob = {}
def __init__(self, name):
# loading all test cases
if self._jsonblob == {}:
for api in self._apis:
with open(pathlib.Path(__file__).parent / f"{api}.json", "r") as f:
TestCase._jsonblob.update({f"{api}-{k}": v for k, v in json.load(f).items()})
# loading this test case
tc = self._jsonblob.get(name)
if tc is None:
raise ValueError(f"Test case {name} not found")
# these are the only fields we need
self.name = tc.get("name")
self.data = tc.get("data")
def __getitem__(self, key):
return self.data[key]

View file

@@ -0,0 +1,160 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import sqlite3
import numpy as np
import pytest
import sqlite_vec
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
SQLiteVecIndex,
SQLiteVecVectorIOAdapter,
generate_chunk_id,
)
# How to run this test:
#
# pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
SQLITE_VEC_PROVIDER = "sqlite_vec"
EMBEDDING_DIMENSION = 384
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
@pytest.fixture(scope="session")
def loop():
return asyncio.new_event_loop()
@pytest.fixture(scope="session", autouse=True)
def sqlite_connection(loop):
conn = sqlite3.connect(":memory:")
try:
conn.enable_load_extension(True)
sqlite_vec.load(conn)
yield conn
finally:
conn.close()
@pytest.fixture(scope="session", autouse=True)
async def sqlite_vec_index(sqlite_connection):
return await SQLiteVecIndex.create(dimension=EMBEDDING_DIMENSION, connection=sqlite_connection, bank_id="test_bank")
@pytest.fixture(scope="session")
def sample_chunks():
"""Generates chunks that force multiple batches for a single document to expose ID conflicts."""
n, k = 10, 3
sample = [
Chunk(content=f"Sentence {i} from document {j}", metadata={"document_id": f"document-{j}"})
for j in range(k)
for i in range(n)
]
return sample
@pytest.fixture(scope="session")
def sample_embeddings(sample_chunks):
np.random.seed(42)
return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])
@pytest.mark.asyncio
async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
cur = sqlite_vec_index.connection.cursor()
cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
count = cur.fetchone()[0]
assert count == len(sample_chunks)
@pytest.mark.asyncio
async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
query_embedding = np.random.rand(EMBEDDING_DIMENSION).astype(np.float32)
response = await sqlite_vec_index.query(query_embedding, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
@pytest.mark.asyncio
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks):
"""Test that chunk IDs do not conflict across batches when inserting chunks."""
# Reduce batch size to force multiple batches for same document
# since there are 10 chunks per document and batch size is 2
batch_size = 2
sample_embeddings = np.random.rand(len(sample_chunks), EMBEDDING_DIMENSION).astype(np.float32)
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=batch_size)
cur = sqlite_vec_index.connection.cursor()
# Retrieve all chunk IDs to check for duplicates
cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
chunk_ids = [row[0] for row in cur.fetchall()]
cur.close()
# Ensure all chunk IDs are unique
assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"
@pytest.fixture(scope="session")
async def sqlite_vec_adapter(sqlite_connection):
config = type("Config", (object,), {"db_path": ":memory:"}) # Mock config with in-memory database
adapter = SQLiteVecVectorIOAdapter(config=config, inference_api=None)
await adapter.initialize()
yield adapter
await adapter.shutdown()
@pytest.mark.asyncio
async def test_register_vector_db(sqlite_vec_adapter):
vector_db = VectorDB(
identifier="test_db",
embedding_model=EMBEDDING_MODEL,
embedding_dimension=EMBEDDING_DIMENSION,
metadata={},
provider_id=SQLITE_VEC_PROVIDER,
)
await sqlite_vec_adapter.register_vector_db(vector_db)
vector_dbs = await sqlite_vec_adapter.list_vector_dbs()
assert any(db.identifier == "test_db" for db in vector_dbs)
@pytest.mark.asyncio
async def test_unregister_vector_db(sqlite_vec_adapter):
vector_db = VectorDB(
identifier="test_db",
embedding_model=EMBEDDING_MODEL,
embedding_dimension=EMBEDDING_DIMENSION,
metadata={},
provider_id=SQLITE_VEC_PROVIDER,
)
await sqlite_vec_adapter.register_vector_db(vector_db)
await sqlite_vec_adapter.unregister_vector_db("test_db")
vector_dbs = await sqlite_vec_adapter.list_vector_dbs()
assert not any(db.identifier == "test_db" for db in vector_dbs)
def test_generate_chunk_id():
chunks = [
Chunk(content="test", metadata={"document_id": "doc-1"}),
Chunk(content="test ", metadata={"document_id": "doc-1"}),
Chunk(content="test 3", metadata={"document_id": "doc-1"}),
]
chunk_ids = sorted([generate_chunk_id(chunk.metadata["document_id"], chunk.content) for chunk in chunks])
assert chunk_ids == [
"177a1368-f6a8-0c50-6e92-18677f2c3de3",
"bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
"f68df25d-d9aa-ab4d-5684-64a233add20d",
]
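
The digests asserted above imply that `generate_chunk_id` is a deterministic function of the document ID and the chunk text. A minimal sketch of one way to get that property — the exact hash layout is an assumption, not necessarily the shipped implementation:

```python
import hashlib
import uuid

def generate_chunk_id_sketch(document_id: str, chunk_text: str) -> str:
    # Deterministic: the same (document_id, content) pair always yields the
    # same UUID, so re-ingesting a chunk cannot mint a second ID.
    # ASSUMPTION: the "doc:text" layout and MD5 are illustrative choices.
    hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
```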

View file

@ -4,8 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections import namedtuple
from typing import List, Optional
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from llama_stack.apis.models.models import ModelType
from llama_stack.models.llama.sku_list import all_registered_models
@ -14,7 +15,15 @@ from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
)
ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"])
# TODO: this class is more confusing than useful right now. We need to make it
# closer to the Model class.
class ProviderModelEntry(BaseModel):
provider_model_id: str
aliases: List[str] = Field(default_factory=list)
llama_model: Optional[str] = None
model_type: ModelType = ModelType.llm
metadata: Dict[str, Any] = Field(default_factory=dict)
def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
@ -24,8 +33,8 @@ def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
return None
def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAlias:
return ModelAlias(
def build_hf_repo_model_entry(provider_model_id: str, model_descriptor: str) -> ProviderModelEntry:
return ProviderModelEntry(
provider_model_id=provider_model_id,
aliases=[
get_huggingface_repo(model_descriptor),
@ -34,26 +43,29 @@ def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAli
)
def build_model_alias_with_just_provider_model_id(provider_model_id: str, model_descriptor: str) -> ModelAlias:
return ModelAlias(
def build_model_entry(provider_model_id: str, model_descriptor: str) -> ProviderModelEntry:
return ProviderModelEntry(
provider_model_id=provider_model_id,
aliases=[],
llama_model=model_descriptor,
model_type=ModelType.llm,
)
class ModelRegistryHelper(ModelsProtocolPrivate):
def __init__(self, model_aliases: List[ModelAlias]):
def __init__(self, model_entries: List[ProviderModelEntry]):
self.alias_to_provider_id_map = {}
self.provider_id_to_llama_model_map = {}
for alias_obj in model_aliases:
for alias in alias_obj.aliases:
self.alias_to_provider_id_map[alias] = alias_obj.provider_model_id
for entry in model_entries:
for alias in entry.aliases:
self.alias_to_provider_id_map[alias] = entry.provider_model_id
# also add a mapping from provider model id to itself for easy lookup
self.alias_to_provider_id_map[alias_obj.provider_model_id] = alias_obj.provider_model_id
# ensure we can go from llama model to provider model id
self.alias_to_provider_id_map[alias_obj.llama_model] = alias_obj.provider_model_id
self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = alias_obj.llama_model
self.alias_to_provider_id_map[entry.provider_model_id] = entry.provider_model_id
if entry.llama_model:
self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model
def get_provider_model_id(self, identifier: str) -> Optional[str]:
return self.alias_to_provider_id_map.get(identifier, None)
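
A short usage sketch of the renamed helper — the model IDs below are illustrative, and the import path is assumed:

```python
# Illustrative only: how ProviderModelEntry aliases resolve after the rename.
from llama_stack.providers.utils.inference.model_registry import (  # path assumed
    ModelRegistryHelper,
    ProviderModelEntry,
)

entry = ProviderModelEntry(
    provider_model_id="accounts/fireworks/models/llama-v3p1-8b-instruct",
    aliases=["meta-llama/Llama-3.1-8B-Instruct"],
    llama_model="Llama3.1-8B-Instruct",
)
helper = ModelRegistryHelper(model_entries=[entry])

# The alias, the provider-native ID, and the llama model descriptor all
# resolve to the provider's model ID.
for name in (
    "meta-llama/Llama-3.1-8B-Instruct",
    "accounts/fireworks/models/llama-v3p1-8b-instruct",
    "Llama3.1-8B-Instruct",
):
    assert helper.get_provider_model_id(name) == entry.provider_model_id
```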

View file

@ -7,7 +7,6 @@ import json
import logging
from typing import AsyncGenerator, Dict, List, Optional, Union
from llama_models.llama3.api.chat_format import ChatFormat
from openai.types.chat import ChatCompletionMessageToolCall
from pydantic import BaseModel
@ -40,6 +39,7 @@ from llama_stack.models.llama.datatypes import (
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url,
decode_assistant_message,
)
logger = logging.getLogger(__name__)
@ -149,7 +149,7 @@ def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Unio
return None
def process_completion_response(response: OpenAICompatCompletionResponse, formatter: ChatFormat) -> CompletionResponse:
def process_completion_response(response: OpenAICompatCompletionResponse) -> CompletionResponse:
choice = response.choices[0]
# drop suffix <eot_id> if present and return stop reason as end of turn
if choice.text.endswith("<|eot_id|>"):
@ -174,16 +174,13 @@ def process_completion_response(response: OpenAICompatCompletionResponse, format
def process_chat_completion_response(
response: OpenAICompatCompletionResponse,
formatter: ChatFormat,
request: ChatCompletionRequest,
) -> ChatCompletionResponse:
choice = response.choices[0]
# TODO: This does not work well with tool calls for vLLM remote provider
# Ref: https://github.com/meta-llama/llama-stack/issues/1058
raw_message = formatter.decode_assistant_message_from_content(
text_from_choice(choice), get_stop_reason(choice.finish_reason)
)
raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))
# NOTE: If we do not set tools in chat-completion request, we should not
# expect the ToolCall in the response. Instead, we should return the raw
@ -217,7 +214,7 @@ def process_chat_completion_response(
async def process_completion_stream_response(
stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
) -> AsyncGenerator:
stop_reason = None
@ -254,7 +251,6 @@ async def process_completion_stream_response(
async def process_chat_completion_stream_response(
stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
formatter: ChatFormat,
request: ChatCompletionRequest,
) -> AsyncGenerator:
yield ChatCompletionResponseStreamChunk(
@ -333,7 +329,7 @@ async def process_chat_completion_stream_response(
)
# parse tool calls and report errors
message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
message = decode_assistant_message(buffer, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:

View file

@ -13,7 +13,9 @@ import re
from typing import List, Optional, Tuple, Union
import httpx
from llama_models.datatypes import StopReason
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from PIL import Image as PIL_Image
from llama_stack.apis.common.content_types import (
@ -66,6 +68,11 @@ class CompletionRequestWithRawContent(CompletionRequest):
content: RawContent
def decode_assistant_message(content: str, stop_reason: StopReason) -> RawMessage:
formatter = ChatFormat(Tokenizer.get_instance())
return formatter.decode_assistant_message_from_content(content, stop_reason)
def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
def _process(c) -> str:
if isinstance(c, str):
@ -207,20 +214,22 @@ async def convert_image_content_to_url(
return base64.b64encode(content).decode("utf-8")
async def completion_request_to_prompt(request: CompletionRequest, formatter: ChatFormat) -> str:
async def completion_request_to_prompt(request: CompletionRequest) -> str:
content = augment_content_with_response_format_prompt(request.response_format, request.content)
request.content = content
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_content(request.content)
return formatter.tokenizer.decode(model_input.tokens)
async def completion_request_to_prompt_model_input_info(
request: CompletionRequest, formatter: ChatFormat
) -> Tuple[str, int]:
async def completion_request_to_prompt_model_input_info(request: CompletionRequest) -> Tuple[str, int]:
content = augment_content_with_response_format_prompt(request.response_format, request.content)
request.content = content
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_content(request.content)
return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens))
@ -237,22 +246,24 @@ def augment_content_with_response_format_prompt(response_format, content):
return content
async def chat_completion_request_to_prompt(
request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
) -> str:
async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
messages = chat_completion_request_to_messages(request, llama_model)
request.messages = messages
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(request.messages)
return formatter.tokenizer.decode(model_input.tokens)
async def chat_completion_request_to_model_input_info(
request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
request: ChatCompletionRequest, llama_model: str
) -> Tuple[str, int]:
messages = chat_completion_request_to_messages(request, llama_model)
request.messages = messages
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(request.messages)
return (
formatter.tokenizer.decode(model_input.tokens),

View file

@ -18,6 +18,7 @@ class KVStoreType(Enum):
redis = "redis"
sqlite = "sqlite"
postgres = "postgres"
mongodb = "mongodb"
class CommonConfig(BaseModel):
@ -101,7 +102,30 @@ class PostgresKVStoreConfig(CommonConfig):
return v
class MongoDBKVStoreConfig(CommonConfig):
type: Literal[KVStoreType.mongodb.value] = KVStoreType.mongodb.value
host: str = "localhost"
port: int = 27017
db: str = "llamastack"
user: Optional[str] = None
password: Optional[str] = None
collection_name: str = "llamastack_kvstore"
@classmethod
def sample_run_config(cls, collection_name: str = "llamastack_kvstore"):
return {
"type": "mongodb",
"namespace": None,
"host": "${env.MONGODB_HOST:localhost}",
"port": "${env.MONGODB_PORT:5432}",
"db": "${env.MONGODB_DB}",
"user": "${env.MONGODB_USER}",
"password": "${env.MONGODB_PASSWORD}",
"collection_name": "${env.MONGODB_COLLECTION_NAME:" + collection_name + "}",
}
KVStoreConfig = Annotated[
Union[RedisKVStoreConfig, SqliteKVStoreConfig, PostgresKVStoreConfig],
Union[RedisKVStoreConfig, SqliteKVStoreConfig, PostgresKVStoreConfig, MongoDBKVStoreConfig],
Field(discriminator="type", default=KVStoreType.sqlite.value),
]
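
Because the union is discriminated on `type`, pydantic can route a raw dict straight to the right config class — a minimal sketch, assuming pydantic v2 and the module path shown:

```python
# Hedged sketch: parse a plain dict into the correct KVStore config class.
from pydantic import TypeAdapter

from llama_stack.providers.utils.kvstore.config import (  # path assumed
    KVStoreConfig,
    MongoDBKVStoreConfig,
)

raw = {"type": "mongodb", "db": "llamastack", "collection_name": "kv"}
config = TypeAdapter(KVStoreConfig).validate_python(raw)
assert isinstance(config, MongoDBKVStoreConfig)
```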

View file

@ -11,7 +11,7 @@ from .config import KVStoreConfig, KVStoreType
def kvstore_dependencies():
return ["aiosqlite", "psycopg2-binary", "redis"]
return ["aiosqlite", "psycopg2-binary", "redis", "pymongo"]
class InmemoryKVStoreImpl(KVStore):
@ -44,6 +44,10 @@ async def kvstore_impl(config: KVStoreConfig) -> KVStore:
from .postgres import PostgresKVStoreImpl
impl = PostgresKVStoreImpl(config)
elif config.type == KVStoreType.mongodb.value:
from .mongodb import MongoDBKVStoreImpl
impl = MongoDBKVStoreImpl(config)
else:
raise ValueError(f"Unknown kvstore type {config.type}")

View file

@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .mongodb import MongoDBKVStoreImpl
__all__ = ["MongoDBKVStoreImpl"]

View file

@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from datetime import datetime
from typing import List, Optional
from pymongo import MongoClient
from llama_stack.providers.utils.kvstore import KVStore, MongoDBKVStoreConfig
log = logging.getLogger(__name__)
class MongoDBKVStoreImpl(KVStore):
def __init__(self, config: MongoDBKVStoreConfig):
self.config = config
self.conn = None
self.collection = None
async def initialize(self) -> None:
try:
conn_creds = {
"host": self.config.host,
"port": self.config.port,
"username": self.config.user,
"password": self.config.password,
}
conn_creds = {k: v for k, v in conn_creds.items() if v is not None}
self.conn = MongoClient(**conn_creds)
self.collection = self.conn[self.config.db][self.config.collection_name]
except Exception as e:
log.exception("Could not connect to MongoDB database server")
raise RuntimeError("Could not connect to MongoDB database server") from e
def _namespaced_key(self, key: str) -> str:
if not self.config.namespace:
return key
return f"{self.config.namespace}:{key}"
async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None:
key = self._namespaced_key(key)
update_query = {"$set": {"value": value, "expiration": expiration}}
self.collection.update_one({"key": key}, update_query, upsert=True)
async def get(self, key: str) -> Optional[str]:
key = self._namespaced_key(key)
query = {"key": key}
result = self.collection.find_one(query, {"value": 1, "_id": 0})
return result["value"] if result else None
async def delete(self, key: str) -> None:
key = self._namespaced_key(key)
self.collection.delete_one({"key": key})
async def range(self, start_key: str, end_key: str) -> List[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
query = {
"key": {"$gte": start_key, "$lt": end_key},
}
cursor = self.collection.find(query, {"value": 1, "_id": 0}).sort("key", 1)
return [doc["value"] for doc in cursor]

View file

@ -19,6 +19,7 @@ class WebMethod:
request_examples: Optional[List[Any]] = None
response_examples: Optional[List[Any]] = None
method: Optional[str] = None
raw_bytes_request_body: Optional[bool] = False
def webmethod(
@ -27,6 +28,7 @@ def webmethod(
public: Optional[bool] = False,
request_examples: Optional[List[Any]] = None,
response_examples: Optional[List[Any]] = None,
raw_bytes_request_body: Optional[bool] = False,
) -> Callable[[T], T]:
"""
Decorator that supplies additional metadata to an endpoint operation function.
@ -44,6 +46,7 @@ def webmethod(
public=public or False,
request_examples=request_examples,
response_examples=response_examples,
raw_bytes_request_body=raw_bytes_request_body,
)
return cls
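
An illustrative endpoint using the new flag — the protocol and route below are hypothetical; only `raw_bytes_request_body` comes from this diff:

```python
from typing import Protocol

class Files(Protocol):  # hypothetical API surface, for illustration only
    @webmethod(route="/files/upload", raw_bytes_request_body=True)
    async def upload(self, body: bytes) -> str:
        """The server should pass the request body through verbatim,
        skipping JSON deserialization, when this flag is set."""
        ...
```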

View file

@ -23,6 +23,22 @@ from llama_stack.distribution.build import (
REPO_ROOT = Path(__file__).parent.parent.parent
class ChangedPathTracker:
"""Track a list of paths we may have changed."""
def __init__(self):
self._changed_paths = []
def add_paths(self, *paths):
for path in paths:
path = str(path)
if path not in self._changed_paths:
self._changed_paths.append(path)
def changed_paths(self):
return self._changed_paths
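
A behavior sketch of the tracker's contract — arguments are stringified, deduplicated, and kept in first-seen order:

```python
from pathlib import Path

# Assumes the ChangedPathTracker class defined above.
tracker = ChangedPathTracker()
tracker.add_paths(Path("distributions/dependencies.json"), "llama_stack/templates/ollama")
tracker.add_paths("distributions/dependencies.json")  # duplicate: ignored

assert tracker.changed_paths() == [
    "distributions/dependencies.json",
    "llama_stack/templates/ollama",
]
```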
def find_template_dirs(templates_dir: Path) -> Iterator[Path]:
"""Find immediate subdirectories in the templates folder."""
if not templates_dir.exists():
@ -31,7 +47,7 @@ def find_template_dirs(templates_dir: Path) -> Iterator[Path]:
return sorted(d for d in templates_dir.iterdir() if d.is_dir() and d.name != "__pycache__")
def process_template(template_dir: Path, progress) -> None:
def process_template(template_dir: Path, progress, change_tracker: ChangedPathTracker) -> None:
"""Process a single template directory."""
progress.print(f"Processing {template_dir.name}")
@ -44,9 +60,12 @@ def process_template(template_dir: Path, progress) -> None:
if template_func := getattr(module, "get_distribution_template", None):
template = template_func()
yaml_output_dir = REPO_ROOT / "llama_stack" / "templates" / template.name
doc_output_dir = REPO_ROOT / "docs/source/distributions" / f"{template.distro_type}_distro"
change_tracker.add_paths(yaml_output_dir, doc_output_dir)
template.save_distribution(
yaml_output_dir=REPO_ROOT / "llama_stack" / "templates" / template.name,
doc_output_dir=REPO_ROOT / "docs/source/distributions" / f"{template.distro_type}_distro",
yaml_output_dir=yaml_output_dir,
doc_output_dir=doc_output_dir,
)
else:
progress.print(f"[yellow]Warning: {template_dir.name} has no get_distribution_template function")
@ -56,14 +75,19 @@ def process_template(template_dir: Path, progress) -> None:
raise e
def check_for_changes() -> bool:
def check_for_changes(change_tracker: ChangedPathTracker) -> bool:
"""Check if there are any uncommitted changes."""
result = subprocess.run(
["git", "diff", "--exit-code"],
cwd=REPO_ROOT,
capture_output=True,
)
return result.returncode != 0
has_changes = False
for path in change_tracker.changed_paths():
result = subprocess.run(
["git", "diff", "--exit-code", path],
cwd=REPO_ROOT,
capture_output=True,
)
if result.returncode != 0:
print(f"Change detected in '{path}'.", file=sys.stderr)
has_changes = True
return has_changes
def collect_template_dependencies(template_dir: Path) -> tuple[str, list[str]]:
@ -83,7 +107,7 @@ def collect_template_dependencies(template_dir: Path) -> tuple[str, list[str]]:
return None, []
def generate_dependencies_file():
def generate_dependencies_file(change_tracker: ChangedPathTracker):
templates_dir = REPO_ROOT / "llama_stack" / "templates"
distribution_deps = {}
@ -93,12 +117,14 @@ def generate_dependencies_file():
distribution_deps[name] = deps
deps_file = REPO_ROOT / "distributions" / "dependencies.json"
change_tracker.add_paths(deps_file)
with open(deps_file, "w") as f:
f.write(json.dumps(distribution_deps, indent=2) + "\n")
def main():
templates_dir = REPO_ROOT / "llama_stack" / "templates"
change_tracker = ChangedPathTracker()
with Progress(
SpinnerColumn(),
@ -108,7 +134,7 @@ def main():
task = progress.add_task("Processing distribution templates...", total=len(template_dirs))
# Create a partial function with the progress bar
process_func = partial(process_template, progress=progress)
process_func = partial(process_template, progress=progress, change_tracker=change_tracker)
# Process templates in parallel
with concurrent.futures.ThreadPoolExecutor() as executor:
@ -116,9 +142,9 @@ def main():
list(executor.map(process_func, template_dirs))
progress.update(task, advance=len(template_dirs))
generate_dependencies_file()
generate_dependencies_file(change_tracker)
if check_for_changes():
if check_for_changes(change_tracker):
print(
"Distribution template changes detected. Please commit the changes.",
file=sys.stderr,

View file

@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import os
from pathlib import Path
import pytest
"""
Script for running client-sdk tests on AsyncLlamaStackAsLibraryClient with templates
Assuming directory structure:
- llama-stack
- llama_stack
- scripts
- tests
- client-sdk
Example command:
cd llama-stack
export TOGETHER_API_KEY=<..>
export FIREWORKS_API_KEY=<..>
python llama_stack/scripts/run_client_sdk_tests.py --templates together fireworks --report
"""
REPO_ROOT = Path(__file__).parent.parent.parent
CLIENT_SDK_TESTS_RELATIVE_PATH = "tests/client-sdk/"
def main(parser: argparse.ArgumentParser):
args = parser.parse_args()
templates_dir = REPO_ROOT / "llama_stack" / "templates"
user_specified_templates = [templates_dir / t for t in args.templates] if args.templates else []
for d in templates_dir.iterdir():
if d.is_dir() and d.name != "__pycache__":
template_configs = list(d.rglob("run.yaml"))
if len(template_configs) == 0:
continue
config = template_configs[0]
if user_specified_templates:
if not any(config.parent == t for t in user_specified_templates):
continue
os.environ["LLAMA_STACK_CONFIG"] = str(config)
pytest_args = "--report" if args.report else ""
pytest.main(
[
pytest_args,
"-s",
"-v",
REPO_ROOT / CLIENT_SDK_TESTS_RELATIVE_PATH,
]
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="llama_test",
)
parser.add_argument("--templates", nargs="+")
parser.add_argument("--report", action="store_true")
main(parser)

View file

@ -10,7 +10,7 @@ from llama_stack.apis.models import ModelInput
from llama_stack.distribution.datatypes import Provider, ToolGroupInput
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.bedrock.bedrock import MODEL_ALIASES
from llama_stack.providers.remote.inference.bedrock.models import MODEL_ENTRIES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@ -47,7 +47,7 @@ def get_distribution_template() -> DistributionTemplate:
provider_model_id=m.provider_model_id,
provider_id="bedrock",
)
for m in MODEL_ALIASES
for m in MODEL_ENTRIES
]
default_tool_groups = [
ToolGroupInput(

View file

@ -14,7 +14,7 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
)
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
from llama_stack.providers.remote.inference.cerebras.cerebras import model_aliases
from llama_stack.providers.remote.inference.cerebras.models import model_entries
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@ -55,7 +55,7 @@ def get_distribution_template() -> DistributionTemplate:
provider_model_id=m.provider_model_id,
provider_id="cerebras",
)
for m in model_aliases
for m in model_entries
]
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",

View file

@ -19,7 +19,7 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
)
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.fireworks.fireworks import MODEL_ALIASES
from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@ -63,11 +63,13 @@ def get_distribution_template() -> DistributionTemplate:
core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
default_models = [
ModelInput(
model_id=core_model_to_hf_repo[m.llama_model],
model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
provider_model_id=m.provider_model_id,
provider_id="fireworks",
metadata=m.metadata,
model_type=m.model_type,
)
for m in MODEL_ALIASES
for m in MODEL_ENTRIES
]
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",

View file

@ -149,6 +149,13 @@ models:
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata:
embedding_dimension: 768
context_length: 8192
model_id: nomic-ai/nomic-embed-text-v1.5
provider_id: fireworks
provider_model_id: nomic-ai/nomic-embed-text-v1.5
model_type: embedding
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2

View file

@ -143,6 +143,13 @@ models:
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata:
embedding_dimension: 768
context_length: 8192
model_id: nomic-ai/nomic-embed-text-v1.5
provider_id: fireworks
provider_model_id: nomic-ai/nomic-embed-text-v1.5
model_type: embedding
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2

View file

@ -10,7 +10,7 @@ from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.nvidia import _MODEL_ALIASES
from llama_stack.providers.remote.inference.nvidia.models import _MODEL_ENTRIES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@ -59,7 +59,7 @@ def get_distribution_template() -> DistributionTemplate:
provider_model_id=m.provider_model_id,
provider_id="nvidia",
)
for m in _MODEL_ALIASES
for m in _MODEL_ENTRIES
]
default_tool_groups = [
ToolGroupInput(

View file

@ -6,7 +6,6 @@ distribution_spec:
- remote::ollama
vector_io:
- inline::faiss
- inline::sqlite_vec
- remote::chromadb
- remote::pgvector
safety:

View file

@ -119,7 +119,7 @@ llama stack run ./run-with-safety.yaml \
### (Optional) Update Model Serving Configuration
```{note}
Please check the [model_aliases](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L45) for the supported Ollama models.
Please check the [model_entries](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L45) for the supported Ollama models.
```
To serve a new model with `ollama`

View file

@ -71,7 +71,8 @@ def get_distribution_template() -> DistributionTemplate:
)
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",
provider_id="sentence-transformers",
provider_id="ollama",
provider_model_id="all-minilm:latest",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,

View file

@ -20,6 +20,13 @@ providers:
provider_type: inline::sentence-transformers
config: {}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
- provider_id: faiss
provider_type: inline::faiss
config:
@ -103,7 +110,8 @@ models:
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
provider_id: ollama
provider_model_id: all-minilm:latest
model_type: embedding
shields:
- shield_id: ${env.SAFETY_MODEL}

Some files were not shown because too many files have changed in this diff.