Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-12 13:00:39 +00:00)
Merge remote-tracking branch 'upstream/main' into add_nvidia_safety_provider
Merging upstream changes
commit 78b1105f5d
112 changed files with 5112 additions and 3313 deletions
@@ -75,19 +75,20 @@ repos:
# - id: markdown-link-check
#   args: ['--quiet']

# - repo: local
#   hooks:
#   - id: distro-codegen
#     name: Distribution Template Codegen
#     additional_dependencies:
#     - rich
#     - pydantic
#     entry: python -m llama_stack.scripts.distro_codegen
#     language: python
#     pass_filenames: false
#     require_serial: true
#     files: ^llama_stack/templates/.*$
#     stages: [manual]
- repo: local
  hooks:
  - id: distro-codegen
    name: Distribution Template Codegen
    additional_dependencies:
    - rich
    - pydantic
    - uv==0.6.0
    entry: uv run python -m llama_stack.scripts.distro_codegen
    language: python
    pass_filenames: false
    require_serial: true
    files: ^llama_stack/templates/.*$
    files: ^llama_stack/providers/.*/inference/.*/models\.py$

ci:
  autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
@@ -139,6 +139,16 @@ $ make html
$ uv run sphinx-autobuild source build/html
```

### Update API Documentation

If you modify or add new API endpoints, update the API documentation accordingly. You can do this by running the following command:

```bash
$ uv sync --extra dev
$ ./docs/openapi_generator/run_openapi_generator.sh
```

The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing.

## License

By contributing to Llama, you agree that your contributions will be licensed
@@ -21,6 +21,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@@ -54,6 +55,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@@ -88,6 +90,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@@ -122,6 +125,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@@ -157,6 +161,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@@ -192,6 +197,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@@ -228,6 +234,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@@ -269,6 +276,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@@ -306,6 +314,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@@ -340,6 +349,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@@ -373,6 +383,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@@ -403,6 +414,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@@ -438,6 +450,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@@ -471,6 +484,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
@@ -505,6 +519,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
docs/_static/llama-stack-spec.html (vendored, 2608 changed lines): file diff suppressed because it is too large.
docs/_static/llama-stack-spec.yaml (vendored, 1728 changed lines): file diff suppressed because it is too large.
@@ -1017,14 +1017,14 @@
" \"content\": SYSTEM_PROMPT_TEMPLATE.format(subject=subset),\n",
"}\n",
"\n",
"client.eval_tasks.register(\n",
" eval_task_id=\"meta-reference::mmmu\",\n",
"client.benchmarks.register(\n",
" benchmark_id=\"meta-reference::mmmu\",\n",
" dataset_id=f\"mmmu-{subset}-{split}\",\n",
" scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"],\n",
")\n",
"\n",
"response = client.eval.evaluate_rows(\n",
" task_id=\"meta-reference::mmmu\",\n",
"response = client.eval.evaluate_rows_alpha(\n",
" benchmark_id=\"meta-reference::mmmu\",\n",
" input_rows=eval_rows,\n",
" scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"],\n",
" task_config={\n",
@@ -1196,14 +1196,14 @@
" provider_id=\"together\",\n",
")\n",
"\n",
"client.eval_tasks.register(\n",
" eval_task_id=\"meta-reference::simpleqa\",\n",
"client.benchmarks.register(\n",
" benchmark_id=\"meta-reference::simpleqa\",\n",
" dataset_id=simpleqa_dataset_id,\n",
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
")\n",
"\n",
"response = client.eval.evaluate_rows(\n",
" task_id=\"meta-reference::simpleqa\",\n",
"response = client.eval.evaluate_rows_alpha(\n",
" benchmark_id=\"meta-reference::simpleqa\",\n",
" input_rows=eval_rows.rows,\n",
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
" task_config={\n",
@@ -1351,8 +1351,8 @@
" \"enable_session_persistence\": False,\n",
"}\n",
"\n",
"response = client.eval.evaluate_rows(\n",
" task_id=\"meta-reference::simpleqa\",\n",
"response = client.eval.evaluate_rows_alpha(\n",
" benchmark_id=\"meta-reference::simpleqa\",\n",
" input_rows=eval_rows.rows,\n",
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
" task_config={\n",
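Taken together, these notebook hunks rename `eval_tasks`/`eval_task_id` to `benchmarks`/`benchmark_id` and `evaluate_rows` to `evaluate_rows_alpha`. A minimal sketch of the renamed flow, assuming a running Llama Stack server; `simpleqa_dataset_id`, `eval_rows`, and the `task_config` contents come from earlier notebook cells and are assumptions here:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # server URL is an assumption

# Register the benchmark under the new `benchmarks` resource name.
client.benchmarks.register(
    benchmark_id="meta-reference::simpleqa",
    dataset_id=simpleqa_dataset_id,  # registered in an earlier notebook cell
    scoring_functions=["llm-as-judge::405b-simpleqa"],
)

# Evaluate rows through the renamed alpha endpoint.
response = client.eval.evaluate_rows_alpha(
    benchmark_id="meta-reference::simpleqa",
    input_rows=eval_rows.rows,  # rows fetched from the dataset in an earlier cell
    scoring_functions=["llm-as-judge::405b-simpleqa"],
    task_config=task_config,  # eval candidate config built in an earlier cell
)
```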
@@ -3,7 +3,7 @@ The RFC Specification (OpenAPI format) is generated from the set of API endpoint
Please install the following packages before running the script:

```
pip install python-openapi json-strong-typing fire PyYAML llama-models
pip install fire PyYAML llama-models
```

Then simply run `sh run_openapi_generator.sh`
@@ -477,6 +477,7 @@ class Generator:
"SyntheticDataGeneration",
"PostTraining",
"BatchInference",
"Files",
]:
op.defining_class.__name__ = f"{op.defining_class.__name__} (Coming Soon)"
print(op.defining_class.__name__)
@@ -520,8 +521,30 @@ class Generator:
# parameters passed anywhere
parameters = path_parameters + query_parameters

# data passed in payload
if op.request_params:
webmethod = getattr(op.func_ref, "__webmethod__", None)
raw_bytes_request_body = False
if webmethod:
raw_bytes_request_body = getattr(webmethod, "raw_bytes_request_body", False)

# data passed in request body as raw bytes cannot have request parameters
if raw_bytes_request_body and op.request_params:
raise ValueError("Cannot have both raw bytes request body and request parameters")

# data passed in request body as raw bytes
if raw_bytes_request_body:
requestBody = RequestBody(
content={
"application/octet-stream": {
"schema": {
"type": "string",
"format": "binary",
}
}
},
required=True,
)
# data passed in payload as JSON and mapped to request parameters
elif op.request_params:
builder = ContentBuilder(self.schema_builder)
first = next(iter(op.request_params))
request_name, request_type = first
@@ -150,7 +150,14 @@ def _get_endpoint_functions(
print(f"Processing {colored(func_name, 'white')}...")
operation_name = func_name
if operation_name.startswith("get_") or operation_name.endswith("/get"):

if webmethod.method == "GET":
prefix = "get"
elif webmethod.method == "DELETE":
prefix = "delete"
elif webmethod.method == "POST":
prefix = "post"
elif operation_name.startswith("get_") or operation_name.endswith("/get"):
prefix = "get"
elif (
operation_name.startswith("delete_")
@@ -160,13 +167,8 @@ def _get_endpoint_functions(
):
prefix = "delete"
else:
if webmethod.method == "GET":
prefix = "get"
elif webmethod.method == "DELETE":
prefix = "delete"
else:
# by default everything else is a POST
prefix = "post"
# by default everything else is a POST
prefix = "post"

yield prefix, operation_name, func_name, func_ref
@@ -78,7 +78,7 @@ class MediaType:

@dataclass
class RequestBody:
content: Dict[str, MediaType]
content: Dict[str, MediaType | Dict[str, Any]]
description: Optional[str] = None
required: Optional[bool] = None
@@ -1,4 +1,4 @@
sphinx
sphinx==8.1.3
myst-parser
linkify
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
@@ -122,3 +122,20 @@ response = agent.create_turn(
    session_id=session_id,
)
```

### Unregistering Vector DBs

If you need to clean up and unregister vector databases, you can do so as follows:

```python
# Unregister a specified vector database
vector_db_id = "my_vector_db_id"
print(f"Unregistering vector database: {vector_db_id}")
client.vector_dbs.unregister(vector_db_id=vector_db_id)


# Unregister all vector databases
for vector_db_id in client.vector_dbs.list():
    print(f"Unregistering vector database: {vector_db_id.identifier}")
    client.vector_dbs.unregister(vector_db_id=vector_db_id.identifier)
```
@@ -60,6 +60,11 @@ Features:
- Disabled dangerous system operations
- Configurable execution timeouts

> ⚠️ Important: The code interpreter tool can operate in a controlled environment locally or on Podman containers. To ensure proper functionality in containerised environments:
> - The container requires privileged access (e.g., --privileged).
> - Users without sufficient permissions may encounter permission errors. (`bwrap: Can't mount devpts on /newroot/dev/pts: Permission denied`)
> - 🔒 Security Warning: Privileged mode grants elevated access and bypasses security restrictions. Use only in local, isolated, or controlled environments.

#### WolframAlpha

The WolframAlpha tool provides access to computational knowledge through the WolframAlpha API.
@@ -103,7 +108,7 @@ Features:

MCP tools are special tools that can interact with llama stack over model context protocol. These tools are dynamically discovered from an MCP endpoint and can be used to extend the agent's capabilities.

Refer to https://github.com/modelcontextprotocol/server for available MCP servers.
Refer to [https://github.com/modelcontextprotocol/servers](https://github.com/modelcontextprotocol/servers) for available MCP servers.

```python
# Register MCP tools
@@ -191,3 +196,36 @@ all_tools = client.tools.list_tools()
# List tools in a specific group
group_tools = client.tools.list_tools(toolgroup_id="search_tools")
```

## Simple Example: Using an Agent with the Code-Interpreter Tool

```python
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.types.agent_create_params import AgentConfig

# Configure the AI agent with necessary parameters
agent_config = AgentConfig(
    name="code-interpreter",
    description="A code interpreter agent for executing Python code snippets",
    instructions="""
    You are a highly reliable, concise, and precise assistant.
    Always show the generated code, never generate your own code, and never anticipate results.
    """,
    model="meta-llama/Llama-3.2-3B-Instruct",
    toolgroups=["builtin::code_interpreter"],
    max_infer_iters=5,
    enable_session_persistence=False,
)

# Instantiate the AI agent with the given configuration
agent = Agent(client, agent_config)

# Start a session
session_id = agent.create_session("tool_session")

# Send a query to the AI agent for code execution
response = agent.create_turn(
    messages=[{"role": "user", "content": "Run this code: print(3 ** 4 - 5 * 2)"}],
    session_id=session_id,
)
```
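The example's `create_turn` call returns a stream of events. A short follow-up sketch for printing them, assuming the `EventLogger` helper shipped with `llama_stack_client` (the import path is an assumption, not shown in this diff):

```python
from llama_stack_client.lib.agents.event_logger import EventLogger

# Hedged sketch: pretty-print the streamed events from the turn created above.
for log in EventLogger().log(response):
    log.print()
```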
@@ -13,6 +13,7 @@ A Llama Stack API is described as a collection of REST endpoints. We currently s
- **DatasetIO**: interface with datasets and data loaders
- **Scoring**: evaluate outputs of the system
- **Eval**: generate outputs (via Inference or Agents) and perform scoring
- **VectorIO**: perform operations on vector stores, such as adding documents, searching, and deleting documents
- **Telemetry**: collect telemetry data from the system

We are working on adding a few more APIs to complete the application lifecycle. These will include:
@@ -41,6 +42,7 @@ Some of these APIs are associated with a set of **Resources**. Here is the mappi
- **Safety** is associated with `Shield` resources.
- **Tool Runtime** is associated with `ToolGroup` resources.
- **DatasetIO** is associated with `Dataset` resources.
- **VectorIO** is associated with `VectorDB` resources.
- **Scoring** is associated with `ScoringFunction` resources.
- **Eval** is associated with `Model` and `Benchmark` resources.
@@ -93,9 +93,10 @@ html_theme_options = {

html_static_path = ["../_static"]
# html_logo = "../_static/llama-stack-logo.png"
html_style = "../_static/css/my_theme.css"
# html_style = "../_static/css/my_theme.css"


def setup(app):
    app.add_css_file("css/my_theme.css")
def dockerhub_role(name, rawtext, text, lineno, inliner, options={}, content=[]):
    url = f"https://hub.docker.com/r/llamastack/{text}"
    node = nodes.reference(rawtext, text, refuri=url, **options)
@@ -6,7 +6,7 @@ This guide will walk you through the process of adding a new API provider to Lla
- Begin by reviewing the [core concepts](../concepts/index.md) of Llama Stack and choose the API your provider belongs to (Inference, Safety, VectorIO, etc.)
- Determine the provider type ({repopath}`Remote::llama_stack/providers/remote` or {repopath}`Inline::llama_stack/providers/inline`). Remote providers make requests to external services, while inline providers execute implementation locally.
- Add your provider to the appropriate {repopath}`Registry::llama_stack/providers/registry/`. Specify pip dependencies necessary.
- Update any distribution {repopath}`Templates::llama_stack/templates/` build.yaml and run.yaml files if they should include your provider by default. Run {repopath}`llama_stack/scripts/distro_codegen.py` if necessary.
- Update any distribution {repopath}`Templates::llama_stack/templates/` build.yaml and run.yaml files if they should include your provider by default. Run {repopath}`llama_stack/scripts/distro_codegen.py` if necessary. Note that `distro_codegen.py` will fail if the new provider causes any distribution template to attempt to import provider-specific dependencies. This usually means the distribution's `get_distribution_template()` code path should only import any necessary Config or model alias definitions from each provider and not the provider's actual implementation.


Here are some example PRs to help you get started:
@@ -10,7 +10,7 @@ conda_env: ollama
apis:
- agents
- inference
- memory
- vector_io
- safety
- telemetry
providers:
@@ -19,7 +19,7 @@ providers:
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:http://localhost:11434}
memory:
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
@@ -61,7 +61,8 @@ docker run \
--port $LLAMA_STACK_PORT \
--env AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
--env AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
--env AWS_SESSION_TOKEN=$AWS_SESSION_TOKEN
--env AWS_SESSION_TOKEN=$AWS_SESSION_TOKEN \
--env AWS_DEFAULT_REGION=$AWS_DEFAULT_REGION
```

### Via Conda
@@ -72,5 +73,6 @@ llama stack run ./run.yaml \
--port $LLAMA_STACK_PORT \
--env AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
--env AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
--env AWS_SESSION_TOKEN=$AWS_SESSION_TOKEN
--env AWS_SESSION_TOKEN=$AWS_SESSION_TOKEN \
--env AWS_DEFAULT_REGION=$AWS_DEFAULT_REGION
```
@@ -47,6 +47,7 @@ The following models are available by default:
- `meta-llama/Llama-3.3-70B-Instruct (accounts/fireworks/models/llama-v3p3-70b-instruct)`
- `meta-llama/Llama-Guard-3-8B (accounts/fireworks/models/llama-guard-3-8b)`
- `meta-llama/Llama-Guard-3-11B-Vision (accounts/fireworks/models/llama-guard-3-11b-vision)`
- `nomic-ai/nomic-embed-text-v1.5 (nomic-ai/nomic-embed-text-v1.5)`


### Prerequisite: API Keys
@@ -130,7 +130,7 @@ llama stack run ./run-with-safety.yaml \
### (Optional) Update Model Serving Configuration

```{note}
Please check the [model_aliases](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L45) for the supported Ollama models.
Please check the [model_entries](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L45) for the supported Ollama models.
```

To serve a new model with `ollama`
@@ -46,6 +46,8 @@ The following models are available by default:
- `meta-llama/Llama-3.3-70B-Instruct`
- `meta-llama/Llama-Guard-3-8B`
- `meta-llama/Llama-Guard-3-11B-Vision`
- `togethercomputer/m2-bert-80M-8k-retrieval`
- `togethercomputer/m2-bert-80M-32k-retrieval`


### Prerequisite: API Keys
@@ -214,10 +214,16 @@ documents = [
    for i, url in enumerate(urls)
]

vector_providers = [
    provider for provider in client.providers.list() if provider.api == "vector_io"
]
provider_id = vector_providers[0].provider_id # Use the first available vector provider

# Register a vector database
vector_db_id = f"test-vector-db-{uuid.uuid4().hex}"
client.vector_dbs.register(
    vector_db_id=vector_db_id,
    provider_id=provider_id,
    embedding_model="all-MiniLM-L6-v2",
    embedding_dimension=384,
)
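A natural follow-up to the registration above is ingesting the `documents` into the new vector DB via the RAG tool runtime (whose `client.tool_runtime.rag_tool.insert` method appears later in this diff). A minimal sketch; the `chunk_size_in_tokens` value is an assumption:

```python
# Hedged sketch: load the documents defined above into the newly registered vector DB.
client.tool_runtime.rag_tool.insert(
    documents=documents,
    vector_db_id=vector_db_id,
    chunk_size_in_tokens=512,  # chunking granularity; value is an assumption
)
```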
@ -42,10 +42,10 @@ from llama_stack_client.types import (
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="get /v1/toolgroups">client.toolgroups.<a href="./src/llama_stack_client/resources/toolgroups.py">list</a>() -> <a href="./src/llama_stack_client/types/toolgroup_list_response.py">ToolgroupListResponse</a></code>
|
||||
- <code title="get /v1/toolgroups/{toolgroup_id}">client.toolgroups.<a href="./src/llama_stack_client/resources/toolgroups.py">get</a>(toolgroup_id) -> <a href="./src/llama_stack_client/types/tool_group.py">ToolGroup</a></code>
|
||||
- <code title="post /v1/toolgroups">client.toolgroups.<a href="./src/llama_stack_client/resources/toolgroups.py">register</a>(\*\*<a href="src/llama_stack_client/types/toolgroup_register_params.py">params</a>) -> None</code>
|
||||
- <code title="delete /v1/toolgroups/{toolgroup_id}">client.toolgroups.<a href="./src/llama_stack_client/resources/toolgroups.py">unregister</a>(toolgroup_id) -> None</code>
|
||||
- <code title="get /v1/toolgroups">client.toolgroups.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/toolgroups.py">list</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/toolgroup_list_response.py">ToolgroupListResponse</a></code>
|
||||
- <code title="get /v1/toolgroups/{toolgroup_id}">client.toolgroups.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/toolgroups.py">get</a>(toolgroup_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/tool_group.py">ToolGroup</a></code>
|
||||
- <code title="post /v1/toolgroups">client.toolgroups.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/toolgroups.py">register</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/toolgroup_register_params.py">params</a>) -> None</code>
|
||||
- <code title="delete /v1/toolgroups/{toolgroup_id}">client.toolgroups.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/toolgroups.py">unregister</a>(toolgroup_id) -> None</code>
|
||||
|
||||
## Tools
|
||||
|
||||
|
@ -57,8 +57,8 @@ from llama_stack_client.types import ListToolsResponse, Tool, ToolListResponse
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="get /v1/tools">client.tools.<a href="./src/llama_stack_client/resources/tools.py">list</a>(\*\*<a href="src/llama_stack_client/types/tool_list_params.py">params</a>) -> <a href="./src/llama_stack_client/types/tool_list_response.py">ToolListResponse</a></code>
|
||||
- <code title="get /v1/tools/{tool_name}">client.tools.<a href="./src/llama_stack_client/resources/tools.py">get</a>(tool_name) -> <a href="./src/llama_stack_client/types/tool.py">Tool</a></code>
|
||||
- <code title="get /v1/tools">client.tools.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/tools.py">list</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/tool_list_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/tool_list_response.py">ToolListResponse</a></code>
|
||||
- <code title="get /v1/tools/{tool_name}">client.tools.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/tools.py">get</a>(tool_name) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/tool.py">Tool</a></code>
|
||||
|
||||
## ToolRuntime
|
||||
|
||||
|
@ -70,15 +70,15 @@ from llama_stack_client.types import ToolDef, ToolInvocationResult
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="post /v1/tool-runtime/invoke">client.tool_runtime.<a href="./src/llama_stack_client/resources/tool_runtime/tool_runtime.py">invoke_tool</a>(\*\*<a href="src/llama_stack_client/types/tool_runtime_invoke_tool_params.py">params</a>) -> <a href="./src/llama_stack_client/types/tool_invocation_result.py">ToolInvocationResult</a></code>
|
||||
- <code title="get /v1/tool-runtime/list-tools">client.tool_runtime.<a href="./src/llama_stack_client/resources/tool_runtime/tool_runtime.py">list_tools</a>(\*\*<a href="src/llama_stack_client/types/tool_runtime_list_tools_params.py">params</a>) -> <a href="./src/llama_stack_client/types/tool_def.py">JSONLDecoder[ToolDef]</a></code>
|
||||
- <code title="post /v1/tool-runtime/invoke">client.tool_runtime.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/tool_runtime/tool_runtime.py">invoke_tool</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/tool_runtime_invoke_tool_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/tool_invocation_result.py">ToolInvocationResult</a></code>
|
||||
- <code title="get /v1/tool-runtime/list-tools">client.tool_runtime.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/tool_runtime/tool_runtime.py">list_tools</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/tool_runtime_list_tools_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/tool_def.py">JSONLDecoder[ToolDef]</a></code>
|
||||
|
||||
### RagTool
|
||||
|
||||
Methods:
|
||||
|
||||
- <code title="post /v1/tool-runtime/rag-tool/insert">client.tool_runtime.rag_tool.<a href="./src/llama_stack_client/resources/tool_runtime/rag_tool.py">insert</a>(\*\*<a href="src/llama_stack_client/types/tool_runtime/rag_tool_insert_params.py">params</a>) -> None</code>
|
||||
- <code title="post /v1/tool-runtime/rag-tool/query">client.tool_runtime.rag_tool.<a href="./src/llama_stack_client/resources/tool_runtime/rag_tool.py">query</a>(\*\*<a href="src/llama_stack_client/types/tool_runtime/rag_tool_query_params.py">params</a>) -> <a href="./src/llama_stack_client/types/shared/query_result.py">QueryResult</a></code>
|
||||
- <code title="post /v1/tool-runtime/rag-tool/insert">client.tool_runtime.rag_tool.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/tool_runtime/rag_tool.py">insert</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/tool_runtime/rag_tool_insert_params.py">params</a>) -> None</code>
|
||||
- <code title="post /v1/tool-runtime/rag-tool/query">client.tool_runtime.rag_tool.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/tool_runtime/rag_tool.py">query</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/tool_runtime/rag_tool_query_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/shared/query_result.py">QueryResult</a></code>
|
||||
|
||||
## Agents
|
||||
|
||||
|
@ -97,8 +97,8 @@ from llama_stack_client.types import (
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="post /v1/agents">client.agents.<a href="./src/llama_stack_client/resources/agents/agents.py">create</a>(\*\*<a href="src/llama_stack_client/types/agent_create_params.py">params</a>) -> <a href="./src/llama_stack_client/types/agent_create_response.py">AgentCreateResponse</a></code>
|
||||
- <code title="delete /v1/agents/{agent_id}">client.agents.<a href="./src/llama_stack_client/resources/agents/agents.py">delete</a>(agent_id) -> None</code>
|
||||
- <code title="post /v1/agents">client.agents.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/agents.py">create</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agent_create_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agent_create_response.py">AgentCreateResponse</a></code>
|
||||
- <code title="delete /v1/agents/{agent_id}">client.agents.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/agents.py">delete</a>(agent_id) -> None</code>
|
||||
|
||||
### Session
|
||||
|
||||
|
@ -110,9 +110,9 @@ from llama_stack_client.types.agents import Session, SessionCreateResponse
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="post /v1/agents/{agent_id}/session">client.agents.session.<a href="./src/llama_stack_client/resources/agents/session.py">create</a>(agent_id, \*\*<a href="src/llama_stack_client/types/agents/session_create_params.py">params</a>) -> <a href="./src/llama_stack_client/types/agents/session_create_response.py">SessionCreateResponse</a></code>
|
||||
- <code title="get /v1/agents/{agent_id}/session/{session_id}">client.agents.session.<a href="./src/llama_stack_client/resources/agents/session.py">retrieve</a>(session_id, \*, agent_id, \*\*<a href="src/llama_stack_client/types/agents/session_retrieve_params.py">params</a>) -> <a href="./src/llama_stack_client/types/agents/session.py">Session</a></code>
|
||||
- <code title="delete /v1/agents/{agent_id}/session/{session_id}">client.agents.session.<a href="./src/llama_stack_client/resources/agents/session.py">delete</a>(session_id, \*, agent_id) -> None</code>
|
||||
- <code title="post /v1/agents/{agent_id}/session">client.agents.session.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/session.py">create</a>(agent_id, \*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/session_create_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/session_create_response.py">SessionCreateResponse</a></code>
|
||||
- <code title="get /v1/agents/{agent_id}/session/{session_id}">client.agents.session.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/session.py">retrieve</a>(session_id, \*, agent_id, \*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/session_retrieve_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/session.py">Session</a></code>
|
||||
- <code title="delete /v1/agents/{agent_id}/session/{session_id}">client.agents.session.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/session.py">delete</a>(session_id, \*, agent_id) -> None</code>
|
||||
|
||||
### Steps
|
||||
|
||||
|
@ -124,7 +124,7 @@ from llama_stack_client.types.agents import StepRetrieveResponse
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="get /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}">client.agents.steps.<a href="./src/llama_stack_client/resources/agents/steps.py">retrieve</a>(step_id, \*, agent_id, session_id, turn_id) -> <a href="./src/llama_stack_client/types/agents/step_retrieve_response.py">StepRetrieveResponse</a></code>
|
||||
- <code title="get /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}">client.agents.steps.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/steps.py">retrieve</a>(step_id, \*, agent_id, session_id, turn_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/step_retrieve_response.py">StepRetrieveResponse</a></code>
|
||||
|
||||
### Turn
|
||||
|
||||
|
@ -136,8 +136,8 @@ from llama_stack_client.types.agents import Turn, TurnCreateResponse
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="post /v1/agents/{agent_id}/session/{session_id}/turn">client.agents.turn.<a href="./src/llama_stack_client/resources/agents/turn.py">create</a>(session_id, \*, agent_id, \*\*<a href="src/llama_stack_client/types/agents/turn_create_params.py">params</a>) -> <a href="./src/llama_stack_client/types/agents/turn_create_response.py">TurnCreateResponse</a></code>
|
||||
- <code title="get /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}">client.agents.turn.<a href="./src/llama_stack_client/resources/agents/turn.py">retrieve</a>(turn_id, \*, agent_id, session_id) -> <a href="./src/llama_stack_client/types/agents/turn.py">Turn</a></code>
|
||||
- <code title="post /v1/agents/{agent_id}/session/{session_id}/turn">client.agents.turn.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/turn.py">create</a>(session_id, \*, agent_id, \*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/turn_create_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/turn_create_response.py">TurnCreateResponse</a></code>
|
||||
- <code title="get /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}">client.agents.turn.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/turn.py">retrieve</a>(turn_id, \*, agent_id, session_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/turn.py">Turn</a></code>
|
||||
|
||||
## BatchInference
|
||||
|
||||
|
@ -149,8 +149,8 @@ from llama_stack_client.types import BatchInferenceChatCompletionResponse
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="post /v1/batch-inference/chat-completion">client.batch_inference.<a href="./src/llama_stack_client/resources/batch_inference.py">chat_completion</a>(\*\*<a href="src/llama_stack_client/types/batch_inference_chat_completion_params.py">params</a>) -> <a href="./src/llama_stack_client/types/batch_inference_chat_completion_response.py">BatchInferenceChatCompletionResponse</a></code>
|
||||
- <code title="post /v1/batch-inference/completion">client.batch_inference.<a href="./src/llama_stack_client/resources/batch_inference.py">completion</a>(\*\*<a href="src/llama_stack_client/types/batch_inference_completion_params.py">params</a>) -> <a href="./src/llama_stack_client/types/shared/batch_completion.py">BatchCompletion</a></code>
|
||||
- <code title="post /v1/batch-inference/chat-completion">client.batch_inference.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/batch_inference.py">chat_completion</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/batch_inference_chat_completion_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/batch_inference_chat_completion_response.py">BatchInferenceChatCompletionResponse</a></code>
|
||||
- <code title="post /v1/batch-inference/completion">client.batch_inference.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/batch_inference.py">completion</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/batch_inference_completion_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/shared/batch_completion.py">BatchCompletion</a></code>
|
||||
|
||||
## Datasets
|
||||
|
||||
|
@ -166,10 +166,10 @@ from llama_stack_client.types import (
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="get /v1/datasets/{dataset_id}">client.datasets.<a href="./src/llama_stack_client/resources/datasets.py">retrieve</a>(dataset_id) -> <a href="./src/llama_stack_client/types/dataset_retrieve_response.py">Optional[DatasetRetrieveResponse]</a></code>
|
||||
- <code title="get /v1/datasets">client.datasets.<a href="./src/llama_stack_client/resources/datasets.py">list</a>() -> <a href="./src/llama_stack_client/types/dataset_list_response.py">DatasetListResponse</a></code>
|
||||
- <code title="post /v1/datasets">client.datasets.<a href="./src/llama_stack_client/resources/datasets.py">register</a>(\*\*<a href="src/llama_stack_client/types/dataset_register_params.py">params</a>) -> None</code>
|
||||
- <code title="delete /v1/datasets/{dataset_id}">client.datasets.<a href="./src/llama_stack_client/resources/datasets.py">unregister</a>(dataset_id) -> None</code>
|
||||
- <code title="get /v1/datasets/{dataset_id}">client.datasets.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/datasets.py">retrieve</a>(dataset_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/dataset_retrieve_response.py">Optional[DatasetRetrieveResponse]</a></code>
|
||||
- <code title="get /v1/datasets">client.datasets.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/datasets.py">list</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/dataset_list_response.py">DatasetListResponse</a></code>
|
||||
- <code title="post /v1/datasets">client.datasets.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/datasets.py">register</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/dataset_register_params.py">params</a>) -> None</code>
|
||||
- <code title="delete /v1/datasets/{dataset_id}">client.datasets.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/datasets.py">unregister</a>(dataset_id) -> None</code>
|
||||
|
||||
## Eval
|
||||
|
||||
|
@ -181,8 +181,8 @@ from llama_stack_client.types import EvaluateResponse, Job
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="post /v1/eval/tasks/{benchmark_id}/evaluations">client.eval.<a href="./src/llama_stack_client/resources/eval/eval.py">evaluate_rows</a>(benchmark_id, \*\*<a href="src/llama_stack_client/types/eval_evaluate_rows_params.py">params</a>) -> <a href="./src/llama_stack_client/types/evaluate_response.py">EvaluateResponse</a></code>
|
||||
- <code title="post /v1/eval/tasks/{benchmark_id}/jobs">client.eval.<a href="./src/llama_stack_client/resources/eval/eval.py">run_eval</a>(benchmark_id, \*\*<a href="src/llama_stack_client/types/eval_run_eval_params.py">params</a>) -> <a href="./src/llama_stack_client/types/job.py">Job</a></code>
|
||||
- <code title="post /v1/eval/tasks/{benchmark_id}/evaluations">client.eval.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/eval/eval.py">evaluate_rows</a>(benchmark_id, \*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/eval_evaluate_rows_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/evaluate_response.py">EvaluateResponse</a></code>
|
||||
- <code title="post /v1/eval/tasks/{benchmark_id}/jobs">client.eval.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/eval/eval.py">run_eval</a>(benchmark_id, \*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/eval_run_eval_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/job.py">Job</a></code>
|
||||
|
||||
### Jobs
|
||||
|
||||
|
@ -194,9 +194,9 @@ from llama_stack_client.types.eval import JobStatusResponse
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="get /v1/eval/tasks/{benchmark_id}/jobs/{job_id}/result">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">retrieve</a>(job_id, \*, benchmark_id) -> <a href="./src/llama_stack_client/types/evaluate_response.py">EvaluateResponse</a></code>
|
||||
- <code title="delete /v1/eval/tasks/{benchmark_id}/jobs/{job_id}">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">cancel</a>(job_id, \*, benchmark_id) -> None</code>
|
||||
- <code title="get /v1/eval/tasks/{benchmark_id}/jobs/{job_id}">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">status</a>(job_id, \*, benchmark_id) -> Optional[JobStatusResponse]</code>
|
||||
- <code title="get /v1/eval/tasks/{benchmark_id}/jobs/{job_id}/result">client.eval.jobs.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/eval/jobs.py">retrieve</a>(job_id, \*, benchmark_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/evaluate_response.py">EvaluateResponse</a></code>
|
||||
- <code title="delete /v1/eval/tasks/{benchmark_id}/jobs/{job_id}">client.eval.jobs.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/eval/jobs.py">cancel</a>(job_id, \*, benchmark_id) -> None</code>
|
||||
- <code title="get /v1/eval/tasks/{benchmark_id}/jobs/{job_id}">client.eval.jobs.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/eval/jobs.py">status</a>(job_id, \*, benchmark_id) -> Optional[JobStatusResponse]</code>
|
||||
|
||||
## Inspect
|
||||
|
||||
|
@ -208,8 +208,8 @@ from llama_stack_client.types import HealthInfo, ProviderInfo, RouteInfo, Versio
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="get /v1/health">client.inspect.<a href="./src/llama_stack_client/resources/inspect.py">health</a>() -> <a href="./src/llama_stack_client/types/health_info.py">HealthInfo</a></code>
|
||||
- <code title="get /v1/version">client.inspect.<a href="./src/llama_stack_client/resources/inspect.py">version</a>() -> <a href="./src/llama_stack_client/types/version_info.py">VersionInfo</a></code>
|
||||
- <code title="get /v1/health">client.inspect.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/inspect.py">health</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/health_info.py">HealthInfo</a></code>
|
||||
- <code title="get /v1/version">client.inspect.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/inspect.py">version</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/version_info.py">VersionInfo</a></code>
|
||||
|
||||
## Inference
|
||||
|
||||
|
@ -227,9 +227,9 @@ from llama_stack_client.types import (
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="post /v1/inference/chat-completion">client.inference.<a href="./src/llama_stack_client/resources/inference.py">chat_completion</a>(\*\*<a href="src/llama_stack_client/types/inference_chat_completion_params.py">params</a>) -> <a href="./src/llama_stack_client/types/inference_chat_completion_response.py">InferenceChatCompletionResponse</a></code>
|
||||
- <code title="post /v1/inference/completion">client.inference.<a href="./src/llama_stack_client/resources/inference.py">completion</a>(\*\*<a href="src/llama_stack_client/types/inference_completion_params.py">params</a>) -> <a href="./src/llama_stack_client/types/inference_completion_response.py">InferenceCompletionResponse</a></code>
|
||||
- <code title="post /v1/inference/embeddings">client.inference.<a href="./src/llama_stack_client/resources/inference.py">embeddings</a>(\*\*<a href="src/llama_stack_client/types/inference_embeddings_params.py">params</a>) -> <a href="./src/llama_stack_client/types/embeddings_response.py">EmbeddingsResponse</a></code>
|
||||
- <code title="post /v1/inference/chat-completion">client.inference.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/inference.py">chat_completion</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/inference_chat_completion_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/inference_chat_completion_response.py">InferenceChatCompletionResponse</a></code>
|
||||
- <code title="post /v1/inference/completion">client.inference.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/inference.py">completion</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/inference_completion_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/inference_completion_response.py">InferenceCompletionResponse</a></code>
|
||||
- <code title="post /v1/inference/embeddings">client.inference.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/inference.py">embeddings</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/inference_embeddings_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/embeddings_response.py">EmbeddingsResponse</a></code>
|
||||
|
||||
## VectorIo
|
||||
|
||||
|
@ -241,8 +241,8 @@ from llama_stack_client.types import QueryChunksResponse
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="post /v1/vector-io/insert">client.vector_io.<a href="./src/llama_stack_client/resources/vector_io.py">insert</a>(\*\*<a href="src/llama_stack_client/types/vector_io_insert_params.py">params</a>) -> None</code>
|
||||
- <code title="post /v1/vector-io/query">client.vector_io.<a href="./src/llama_stack_client/resources/vector_io.py">query</a>(\*\*<a href="src/llama_stack_client/types/vector_io_query_params.py">params</a>) -> <a href="./src/llama_stack_client/types/query_chunks_response.py">QueryChunksResponse</a></code>
|
||||
- <code title="post /v1/vector-io/insert">client.vector_io.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/vector_io.py">insert</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/vector_io_insert_params.py">params</a>) -> None</code>
|
||||
- <code title="post /v1/vector-io/query">client.vector_io.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/vector_io.py">query</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/vector_io_query_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/query_chunks_response.py">QueryChunksResponse</a></code>
|
||||
|
||||
## VectorDBs
|
||||
|
||||
|
@ -259,10 +259,10 @@ from llama_stack_client.types import (
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="get /v1/vector-dbs/{vector_db_id}">client.vector_dbs.<a href="./src/llama_stack_client/resources/vector_dbs.py">retrieve</a>(vector_db_id) -> <a href="./src/llama_stack_client/types/vector_db_retrieve_response.py">Optional[VectorDBRetrieveResponse]</a></code>
|
||||
- <code title="get /v1/vector-dbs">client.vector_dbs.<a href="./src/llama_stack_client/resources/vector_dbs.py">list</a>() -> <a href="./src/llama_stack_client/types/vector_db_list_response.py">VectorDBListResponse</a></code>
|
||||
- <code title="post /v1/vector-dbs">client.vector_dbs.<a href="./src/llama_stack_client/resources/vector_dbs.py">register</a>(\*\*<a href="src/llama_stack_client/types/vector_db_register_params.py">params</a>) -> <a href="./src/llama_stack_client/types/vector_db_register_response.py">VectorDBRegisterResponse</a></code>
|
||||
- <code title="delete /v1/vector-dbs/{vector_db_id}">client.vector_dbs.<a href="./src/llama_stack_client/resources/vector_dbs.py">unregister</a>(vector_db_id) -> None</code>
|
||||
- <code title="get /v1/vector-dbs/{vector_db_id}">client.vector_dbs.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/vector_dbs.py">retrieve</a>(vector_db_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/vector_db_retrieve_response.py">Optional[VectorDBRetrieveResponse]</a></code>
|
||||
- <code title="get /v1/vector-dbs">client.vector_dbs.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/vector_dbs.py">list</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/vector_db_list_response.py">VectorDBListResponse</a></code>
|
||||
- <code title="post /v1/vector-dbs">client.vector_dbs.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/vector_dbs.py">register</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/vector_db_register_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/vector_db_register_response.py">VectorDBRegisterResponse</a></code>
|
||||
- <code title="delete /v1/vector-dbs/{vector_db_id}">client.vector_dbs.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/vector_dbs.py">unregister</a>(vector_db_id) -> None</code>
|
||||
|
||||
## Models
|
||||
|
||||
|
@ -274,10 +274,10 @@ from llama_stack_client.types import ListModelsResponse, Model, ModelListRespons
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="get /v1/models/{model_id}">client.models.<a href="./src/llama_stack_client/resources/models.py">retrieve</a>(model_id) -> <a href="./src/llama_stack_client/types/model.py">Optional[Model]</a></code>
|
||||
- <code title="get /v1/models">client.models.<a href="./src/llama_stack_client/resources/models.py">list</a>() -> <a href="./src/llama_stack_client/types/model_list_response.py">ModelListResponse</a></code>
|
||||
- <code title="post /v1/models">client.models.<a href="./src/llama_stack_client/resources/models.py">register</a>(\*\*<a href="src/llama_stack_client/types/model_register_params.py">params</a>) -> <a href="./src/llama_stack_client/types/model.py">Model</a></code>
|
||||
- <code title="delete /v1/models/{model_id}">client.models.<a href="./src/llama_stack_client/resources/models.py">unregister</a>(model_id) -> None</code>
|
||||
- <code title="get /v1/models/{model_id}">client.models.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/models.py">retrieve</a>(model_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/model.py">Optional[Model]</a></code>
|
||||
- <code title="get /v1/models">client.models.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/models.py">list</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/model_list_response.py">ModelListResponse</a></code>
|
||||
- <code title="post /v1/models">client.models.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/models.py">register</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/model_register_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/model.py">Model</a></code>
|
||||
- <code title="delete /v1/models/{model_id}">client.models.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/models.py">unregister</a>(model_id) -> None</code>
|
||||
|
||||
## PostTraining
|
||||
|
||||
|
@ -289,8 +289,8 @@ from llama_stack_client.types import ListPostTrainingJobsResponse, PostTrainingJ
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="post /v1/post-training/preference-optimize">client.post_training.<a href="./src/llama_stack_client/resources/post_training/post_training.py">preference_optimize</a>(\*\*<a href="src/llama_stack_client/types/post_training_preference_optimize_params.py">params</a>) -> <a href="./src/llama_stack_client/types/post_training_job.py">PostTrainingJob</a></code>
|
||||
- <code title="post /v1/post-training/supervised-fine-tune">client.post_training.<a href="./src/llama_stack_client/resources/post_training/post_training.py">supervised_fine_tune</a>(\*\*<a href="src/llama_stack_client/types/post_training_supervised_fine_tune_params.py">params</a>) -> <a href="./src/llama_stack_client/types/post_training_job.py">PostTrainingJob</a></code>
|
||||
- <code title="post /v1/post-training/preference-optimize">client.post_training.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/post_training/post_training.py">preference_optimize</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/post_training_preference_optimize_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/post_training_job.py">PostTrainingJob</a></code>
|
||||
- <code title="post /v1/post-training/supervised-fine-tune">client.post_training.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/post_training/post_training.py">supervised_fine_tune</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/post_training_supervised_fine_tune_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/post_training_job.py">PostTrainingJob</a></code>
|
||||
|
||||
### Job
|
||||
|
||||
|
@ -306,10 +306,10 @@ from llama_stack_client.types.post_training import (
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="get /v1/post-training/jobs">client.post_training.job.<a href="./src/llama_stack_client/resources/post_training/job.py">list</a>() -> <a href="./src/llama_stack_client/types/post_training/job_list_response.py">JobListResponse</a></code>
|
||||
- <code title="get /v1/post-training/job/artifacts">client.post_training.job.<a href="./src/llama_stack_client/resources/post_training/job.py">artifacts</a>(\*\*<a href="src/llama_stack_client/types/post_training/job_artifacts_params.py">params</a>) -> <a href="./src/llama_stack_client/types/post_training/job_artifacts_response.py">Optional[JobArtifactsResponse]</a></code>
|
||||
- <code title="post /v1/post-training/job/cancel">client.post_training.job.<a href="./src/llama_stack_client/resources/post_training/job.py">cancel</a>(\*\*<a href="src/llama_stack_client/types/post_training/job_cancel_params.py">params</a>) -> None</code>
|
||||
- <code title="get /v1/post-training/job/status">client.post_training.job.<a href="./src/llama_stack_client/resources/post_training/job.py">status</a>(\*\*<a href="src/llama_stack_client/types/post_training/job_status_params.py">params</a>) -> <a href="./src/llama_stack_client/types/post_training/job_status_response.py">Optional[JobStatusResponse]</a></code>
|
||||
- <code title="get /v1/post-training/jobs">client.post_training.job.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/post_training/job.py">list</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/post_training/job_list_response.py">JobListResponse</a></code>
|
||||
- <code title="get /v1/post-training/job/artifacts">client.post_training.job.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/post_training/job.py">artifacts</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/post_training/job_artifacts_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/post_training/job_artifacts_response.py">Optional[JobArtifactsResponse]</a></code>
|
||||
- <code title="post /v1/post-training/job/cancel">client.post_training.job.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/post_training/job.py">cancel</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/post_training/job_cancel_params.py">params</a>) -> None</code>
|
||||
- <code title="get /v1/post-training/job/status">client.post_training.job.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/post_training/job.py">status</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/post_training/job_status_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/post_training/job_status_response.py">Optional[JobStatusResponse]</a></code>
|
||||
|
||||
## Providers
|
||||
|
||||
|
@ -321,7 +321,7 @@ from llama_stack_client.types import ListProvidersResponse, ProviderListResponse
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="get /v1/inspect/providers">client.providers.<a href="./src/llama_stack_client/resources/providers.py">list</a>() -> <a href="./src/llama_stack_client/types/provider_list_response.py">ProviderListResponse</a></code>
|
||||
- <code title="get /v1/inspect/providers">client.providers.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/providers.py">list</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/provider_list_response.py">ProviderListResponse</a></code>
|
||||
|
||||
## Routes
|
||||
|
||||
|
@ -333,7 +333,7 @@ from llama_stack_client.types import ListRoutesResponse, RouteListResponse
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="get /v1/inspect/routes">client.routes.<a href="./src/llama_stack_client/resources/routes.py">list</a>() -> <a href="./src/llama_stack_client/types/route_list_response.py">RouteListResponse</a></code>
|
||||
- <code title="get /v1/inspect/routes">client.routes.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/routes.py">list</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/route_list_response.py">RouteListResponse</a></code>
|
||||
|
||||
## Safety
|
||||
|
||||
|
@ -345,7 +345,7 @@ from llama_stack_client.types import RunShieldResponse
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="post /v1/safety/run-shield">client.safety.<a href="./src/llama_stack_client/resources/safety.py">run_shield</a>(\*\*<a href="src/llama_stack_client/types/safety_run_shield_params.py">params</a>) -> <a href="./src/llama_stack_client/types/run_shield_response.py">RunShieldResponse</a></code>
|
||||
- <code title="post /v1/safety/run-shield">client.safety.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/safety.py">run_shield</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/safety_run_shield_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/run_shield_response.py">RunShieldResponse</a></code>
|
||||
|
||||
## Shields
|
||||
|
||||
|
@ -357,9 +357,9 @@ from llama_stack_client.types import ListShieldsResponse, Shield, ShieldListResp
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="get /v1/shields/{identifier}">client.shields.<a href="./src/llama_stack_client/resources/shields.py">retrieve</a>(identifier) -> <a href="./src/llama_stack_client/types/shield.py">Optional[Shield]</a></code>
|
||||
- <code title="get /v1/shields">client.shields.<a href="./src/llama_stack_client/resources/shields.py">list</a>() -> <a href="./src/llama_stack_client/types/shield_list_response.py">ShieldListResponse</a></code>
|
||||
- <code title="post /v1/shields">client.shields.<a href="./src/llama_stack_client/resources/shields.py">register</a>(\*\*<a href="src/llama_stack_client/types/shield_register_params.py">params</a>) -> <a href="./src/llama_stack_client/types/shield.py">Shield</a></code>
|
||||
- <code title="get /v1/shields/{identifier}">client.shields.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/shields.py">retrieve</a>(identifier) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/shield.py">Optional[Shield]</a></code>
|
||||
- <code title="get /v1/shields">client.shields.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/shields.py">list</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/shield_list_response.py">ShieldListResponse</a></code>
|
||||
- <code title="post /v1/shields">client.shields.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/shields.py">register</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/shield_register_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/shield.py">Shield</a></code>
|
||||
|
||||
## SyntheticDataGeneration
|
||||
|
||||
|
@ -371,7 +371,7 @@ from llama_stack_client.types import SyntheticDataGenerationResponse
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="post /v1/synthetic-data-generation/generate">client.synthetic_data_generation.<a href="./src/llama_stack_client/resources/synthetic_data_generation.py">generate</a>(\*\*<a href="src/llama_stack_client/types/synthetic_data_generation_generate_params.py">params</a>) -> <a href="./src/llama_stack_client/types/synthetic_data_generation_response.py">SyntheticDataGenerationResponse</a></code>
|
||||
- <code title="post /v1/synthetic-data-generation/generate">client.synthetic_data_generation.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/synthetic_data_generation.py">generate</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/synthetic_data_generation_generate_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/synthetic_data_generation_response.py">SyntheticDataGenerationResponse</a></code>
|
||||
|
||||
## Telemetry
|
||||
|
||||
|
@ -391,13 +391,13 @@ from llama_stack_client.types import (
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="get /v1/telemetry/traces/{trace_id}/spans/{span_id}">client.telemetry.<a href="./src/llama_stack_client/resources/telemetry.py">get_span</a>(span_id, \*, trace_id) -> <a href="./src/llama_stack_client/types/telemetry_get_span_response.py">TelemetryGetSpanResponse</a></code>
|
||||
- <code title="get /v1/telemetry/spans/{span_id}/tree">client.telemetry.<a href="./src/llama_stack_client/resources/telemetry.py">get_span_tree</a>(span_id, \*\*<a href="src/llama_stack_client/types/telemetry_get_span_tree_params.py">params</a>) -> <a href="./src/llama_stack_client/types/telemetry_get_span_tree_response.py">TelemetryGetSpanTreeResponse</a></code>
|
||||
- <code title="get /v1/telemetry/traces/{trace_id}">client.telemetry.<a href="./src/llama_stack_client/resources/telemetry.py">get_trace</a>(trace_id) -> <a href="./src/llama_stack_client/types/trace.py">Trace</a></code>
|
||||
- <code title="post /v1/telemetry/events">client.telemetry.<a href="./src/llama_stack_client/resources/telemetry.py">log_event</a>(\*\*<a href="src/llama_stack_client/types/telemetry_log_event_params.py">params</a>) -> None</code>
|
||||
- <code title="get /v1/telemetry/spans">client.telemetry.<a href="./src/llama_stack_client/resources/telemetry.py">query_spans</a>(\*\*<a href="src/llama_stack_client/types/telemetry_query_spans_params.py">params</a>) -> <a href="./src/llama_stack_client/types/telemetry_query_spans_response.py">TelemetryQuerySpansResponse</a></code>
|
||||
- <code title="get /v1/telemetry/traces">client.telemetry.<a href="./src/llama_stack_client/resources/telemetry.py">query_traces</a>(\*\*<a href="src/llama_stack_client/types/telemetry_query_traces_params.py">params</a>) -> <a href="./src/llama_stack_client/types/telemetry_query_traces_response.py">TelemetryQueryTracesResponse</a></code>
|
||||
- <code title="post /v1/telemetry/spans/export">client.telemetry.<a href="./src/llama_stack_client/resources/telemetry.py">save_spans_to_dataset</a>(\*\*<a href="src/llama_stack_client/types/telemetry_save_spans_to_dataset_params.py">params</a>) -> None</code>
|
||||
- <code title="get /v1/telemetry/traces/{trace_id}/spans/{span_id}">client.telemetry.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/telemetry.py">get_span</a>(span_id, \*, trace_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/telemetry_get_span_response.py">TelemetryGetSpanResponse</a></code>
|
||||
- <code title="get /v1/telemetry/spans/{span_id}/tree">client.telemetry.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/telemetry.py">get_span_tree</a>(span_id, \*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/telemetry_get_span_tree_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/telemetry_get_span_tree_response.py">TelemetryGetSpanTreeResponse</a></code>
|
||||
- <code title="get /v1/telemetry/traces/{trace_id}">client.telemetry.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/telemetry.py">get_trace</a>(trace_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/trace.py">Trace</a></code>
|
||||
- <code title="post /v1/telemetry/events">client.telemetry.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/telemetry.py">log_event</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/telemetry_log_event_params.py">params</a>) -> None</code>
|
||||
- <code title="get /v1/telemetry/spans">client.telemetry.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/telemetry.py">query_spans</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/telemetry_query_spans_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/telemetry_query_spans_response.py">TelemetryQuerySpansResponse</a></code>
|
||||
- <code title="get /v1/telemetry/traces">client.telemetry.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/telemetry.py">query_traces</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/telemetry_query_traces_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/telemetry_query_traces_response.py">TelemetryQueryTracesResponse</a></code>
|
||||
- <code title="post /v1/telemetry/spans/export">client.telemetry.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/telemetry.py">save_spans_to_dataset</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/telemetry_save_spans_to_dataset_params.py">params</a>) -> None</code>
|
||||
|
||||
## Datasetio
|
||||
|
||||
|
@ -409,8 +409,8 @@ from llama_stack_client.types import PaginatedRowsResult
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="post /v1/datasetio/rows">client.datasetio.<a href="./src/llama_stack_client/resources/datasetio.py">append_rows</a>(\*\*<a href="src/llama_stack_client/types/datasetio_append_rows_params.py">params</a>) -> None</code>
|
||||
- <code title="get /v1/datasetio/rows">client.datasetio.<a href="./src/llama_stack_client/resources/datasetio.py">get_rows_paginated</a>(\*\*<a href="src/llama_stack_client/types/datasetio_get_rows_paginated_params.py">params</a>) -> <a href="./src/llama_stack_client/types/paginated_rows_result.py">PaginatedRowsResult</a></code>
|
||||
- <code title="post /v1/datasetio/rows">client.datasetio.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/datasetio.py">append_rows</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/datasetio_append_rows_params.py">params</a>) -> None</code>
|
||||
- <code title="get /v1/datasetio/rows">client.datasetio.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/datasetio.py">get_rows_paginated</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/datasetio_get_rows_paginated_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/paginated_rows_result.py">PaginatedRowsResult</a></code>
|
||||
|
||||
## Scoring
|
||||
|
||||
|
@ -422,8 +422,8 @@ from llama_stack_client.types import ScoringScoreResponse, ScoringScoreBatchResp
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="post /v1/scoring/score">client.scoring.<a href="./src/llama_stack_client/resources/scoring.py">score</a>(\*\*<a href="src/llama_stack_client/types/scoring_score_params.py">params</a>) -> <a href="./src/llama_stack_client/types/scoring_score_response.py">ScoringScoreResponse</a></code>
|
||||
- <code title="post /v1/scoring/score-batch">client.scoring.<a href="./src/llama_stack_client/resources/scoring.py">score_batch</a>(\*\*<a href="src/llama_stack_client/types/scoring_score_batch_params.py">params</a>) -> <a href="./src/llama_stack_client/types/scoring_score_batch_response.py">ScoringScoreBatchResponse</a></code>
|
||||
- <code title="post /v1/scoring/score">client.scoring.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/scoring.py">score</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/scoring_score_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/scoring_score_response.py">ScoringScoreResponse</a></code>
|
||||
- <code title="post /v1/scoring/score-batch">client.scoring.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/scoring.py">score_batch</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/scoring_score_batch_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/scoring_score_batch_response.py">ScoringScoreBatchResponse</a></code>
|
||||
|
||||
## ScoringFunctions
|
||||
|
||||
|
@ -439,9 +439,9 @@ from llama_stack_client.types import (
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="get /v1/scoring-functions/{scoring_fn_id}">client.scoring_functions.<a href="./src/llama_stack_client/resources/scoring_functions.py">retrieve</a>(scoring_fn_id) -> <a href="./src/llama_stack_client/types/scoring_fn.py">Optional[ScoringFn]</a></code>
|
||||
- <code title="get /v1/scoring-functions">client.scoring_functions.<a href="./src/llama_stack_client/resources/scoring_functions.py">list</a>() -> <a href="./src/llama_stack_client/types/scoring_function_list_response.py">ScoringFunctionListResponse</a></code>
|
||||
- <code title="post /v1/scoring-functions">client.scoring_functions.<a href="./src/llama_stack_client/resources/scoring_functions.py">register</a>(\*\*<a href="src/llama_stack_client/types/scoring_function_register_params.py">params</a>) -> None</code>
|
||||
- <code title="get /v1/scoring-functions/{scoring_fn_id}">client.scoring_functions.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/scoring_functions.py">retrieve</a>(scoring_fn_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/scoring_fn.py">Optional[ScoringFn]</a></code>
|
||||
- <code title="get /v1/scoring-functions">client.scoring_functions.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/scoring_functions.py">list</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/scoring_function_list_response.py">ScoringFunctionListResponse</a></code>
|
||||
- <code title="post /v1/scoring-functions">client.scoring_functions.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/scoring_functions.py">register</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/scoring_function_register_params.py">params</a>) -> None</code>
|
||||
|
||||
## Benchmarks
|
||||
|
||||
|
@ -457,6 +457,6 @@ from llama_stack_client.types import (
|
|||
|
||||
Methods:
|
||||
|
||||
- <code title="get /v1/eval-tasks/{benchmark_id}">client.benchmarks.<a href="./src/llama_stack_client/resources/benchmarks.py">retrieve</a>(benchmark_id) -> <a href="./src/llama_stack_client/types/benchmark.py">Optional[Benchmark]</a></code>
|
||||
- <code title="get /v1/eval-tasks">client.benchmarks.<a href="./src/llama_stack_client/resources/benchmarks.py">list</a>() -> <a href="./src/llama_stack_client/types/benchmark_list_response.py">BenchmarkListResponse</a></code>
|
||||
- <code title="post /v1/eval-tasks">client.benchmarks.<a href="./src/llama_stack_client/resources/benchmarks.py">register</a>(\*\*<a href="src/llama_stack_client/types/benchmark_register_params.py">params</a>) -> None</code>
|
||||
- <code title="get /v1/eval-tasks/{benchmark_id}">client.benchmarks.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/benchmarks.py">retrieve</a>(benchmark_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/benchmark.py">Optional[Benchmark]</a></code>
|
||||
- <code title="get /v1/eval-tasks">client.benchmarks.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/benchmarks.py">list</a>() -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/benchmark_list_response.py">BenchmarkListResponse</a></code>
|
||||
- <code title="post /v1/eval-tasks">client.benchmarks.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/benchmarks.py">register</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/benchmark_register_params.py">params</a>) -> None</code>
@ -64,23 +64,3 @@ class Benchmarks(Protocol):
|
|||
provider_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/eval-tasks", method="GET")
|
||||
async def DEPRECATED_list_eval_tasks(self) -> ListBenchmarksResponse: ...
|
||||
|
||||
@webmethod(route="/eval-tasks/{eval_task_id}", method="GET")
|
||||
async def DEPRECATED_get_eval_task(
|
||||
self,
|
||||
eval_task_id: str,
|
||||
) -> Optional[Benchmark]: ...
|
||||
|
||||
@webmethod(route="/eval-tasks", method="POST")
|
||||
async def DEPRECATED_register_eval_task(
|
||||
self,
|
||||
eval_task_id: str,
|
||||
dataset_id: str,
|
||||
scoring_functions: List[str],
|
||||
provider_benchmark_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None: ...
@ -39,7 +39,6 @@ EvalCandidate = register_schema(
|
|||
|
||||
@json_schema_type
|
||||
class BenchmarkConfig(BaseModel):
|
||||
type: Literal["benchmark"] = "benchmark"
|
||||
eval_candidate: EvalCandidate
|
||||
scoring_params: Dict[str, ScoringFnParams] = Field(
|
||||
description="Map between scoring function id and parameters for each scoring function you want to run",
|
||||
|
@ -84,28 +83,3 @@ class Eval(Protocol):
|
|||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
|
||||
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: ...
|
||||
|
||||
@webmethod(route="/eval/tasks/{task_id}/jobs", method="POST")
|
||||
async def DEPRECATED_run_eval(
|
||||
self,
|
||||
task_id: str,
|
||||
task_config: BenchmarkConfig,
|
||||
) -> Job: ...
|
||||
|
||||
@webmethod(route="/eval/tasks/{task_id}/evaluations", method="POST")
|
||||
async def DEPRECATED_evaluate_rows(
|
||||
self,
|
||||
task_id: str,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
task_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse: ...
|
||||
|
||||
@webmethod(route="/eval/tasks/{task_id}/jobs/{job_id}", method="GET")
|
||||
async def DEPRECATED_job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...
|
||||
|
||||
@webmethod(route="/eval/tasks/{task_id}/jobs/{job_id}", method="DELETE")
|
||||
async def DEPRECATED_job_cancel(self, task_id: str, job_id: str) -> None: ...
|
||||
|
||||
@webmethod(route="/eval/tasks/{task_id}/jobs/{job_id}/result", method="GET")
|
||||
async def DEPRECATED_job_result(self, task_id: str, job_id: str) -> EvaluateResponse: ...
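With the task-scoped `/eval/tasks/*` routes removed, callers poll evaluation jobs through the benchmark-scoped endpoints kept above. A rough sketch over plain HTTP, where the host, port, and `/v1` prefix are assumptions and the identifiers are placeholders:

```python
import requests

BASE = "http://localhost:8321/v1"  # assumed local server address
benchmark_id = "my-benchmark"      # hypothetical identifiers
job_id = "job-123"

# GET /eval/benchmarks/{benchmark_id}/jobs/{job_id}/result
resp = requests.get(f"{BASE}/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result")
resp.raise_for_status()
print(resp.json())
```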
7
llama_stack/apis/files/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .files import * # noqa: F401 F403
|
174
llama_stack/apis/files/files.py
Normal file
|
@ -0,0 +1,174 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FileUploadResponse(BaseModel):
|
||||
"""
|
||||
Response after initiating a file upload session.
|
||||
|
||||
:param id: ID of the upload session
|
||||
:param url: Upload URL for the file or file parts
|
||||
:param offset: Upload content offset
|
||||
:param size: Upload content size
|
||||
"""
|
||||
|
||||
id: str
|
||||
url: str
|
||||
offset: int
|
||||
size: int
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BucketResponse(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListBucketResponse(BaseModel):
|
||||
"""
|
||||
Response representing a list of bucket entries.
|
||||
|
||||
:param data: List of BucketResponse entries
|
||||
"""
|
||||
|
||||
data: List[BucketResponse]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FileResponse(BaseModel):
|
||||
"""
|
||||
Response representing a file entry.
|
||||
|
||||
:param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-)
|
||||
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
|
||||
:param mime_type: MIME type of the file
|
||||
:param url: Upload URL for the file contents
|
||||
:param bytes: Size of the file in bytes
|
||||
:param created_at: Timestamp of when the file was created
|
||||
"""
|
||||
|
||||
bucket: str
|
||||
key: str
|
||||
mime_type: str
|
||||
url: str
|
||||
bytes: int
|
||||
created_at: int
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListFileResponse(BaseModel):
|
||||
"""
|
||||
Response representing a list of file entries.
|
||||
|
||||
:param data: List of FileResponse entries
|
||||
"""
|
||||
|
||||
data: List[FileResponse]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Files(Protocol):
|
||||
@webmethod(route="/files", method="POST")
|
||||
async def create_upload_session(
|
||||
self,
|
||||
bucket: str,
|
||||
key: str,
|
||||
mime_type: str,
|
||||
size: int,
|
||||
) -> FileUploadResponse:
|
||||
"""
|
||||
Create a new upload session for a file identified by a bucket and key.
|
||||
|
||||
:param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-)
|
||||
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
|
||||
:param mime_type: MIME type of the file
|
||||
:param size: File size in bytes
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/files/session:{upload_id}", method="POST", raw_bytes_request_body=True)
|
||||
async def upload_content_to_session(
|
||||
self,
|
||||
upload_id: str,
|
||||
) -> Optional[FileResponse]:
|
||||
"""
|
||||
Upload file content to an existing upload session.
|
||||
On the server, the request body will contain the raw bytes that are uploaded.
|
||||
|
||||
:param upload_id: ID of the upload session
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/files/session:{upload_id}", method="GET")
|
||||
async def get_upload_session_info(
|
||||
self,
|
||||
upload_id: str,
|
||||
) -> Optional[FileUploadResponse]:
|
||||
"""
|
||||
Returns information about an existing upload session.
|
||||
|
||||
:param upload_id: ID of the upload session
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/files", method="GET")
|
||||
async def list_all_buckets(
|
||||
self,
|
||||
bucket: str,
|
||||
) -> ListBucketResponse:
|
||||
"""
|
||||
List all buckets.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/files/{bucket}", method="GET")
|
||||
async def list_files_in_bucket(
|
||||
self,
|
||||
bucket: str,
|
||||
) -> ListFileResponse:
|
||||
"""
|
||||
List all files in a bucket.
|
||||
|
||||
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/files/{bucket}/{key:path}", method="GET")
|
||||
async def get_file(
|
||||
self,
|
||||
bucket: str,
|
||||
key: str,
|
||||
) -> FileResponse:
|
||||
"""
|
||||
Get a file info identified by a bucket and key.
|
||||
|
||||
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
|
||||
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/files/{bucket}/{key:path}", method="DELETE")
|
||||
async def delete_file(
|
||||
self,
|
||||
bucket: str,
|
||||
key: str,
|
||||
) -> FileResponse:
|
||||
"""
|
||||
Delete a file identified by a bucket and key.
|
||||
|
||||
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
|
||||
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
|
||||
"""
|
||||
...
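The upload flow in this protocol is two-step: create a session, then send the raw bytes to it. A rough client-side sketch follows; the server address and the `/v1` route prefix are assumptions, and the bucket/key values are placeholders.

```python
import requests

BASE = "http://localhost:8321/v1"  # assumed local server address
payload = b"hello world"

# POST /files: create an upload session for a bucket/key pair
session = requests.post(
    f"{BASE}/files",
    json={"bucket": "docs", "key": "notes/hello.txt", "mime_type": "text/plain", "size": len(payload)},
).json()

# POST /files/session:{upload_id}: the request body carries the raw bytes
requests.post(f"{BASE}/files/session:{session['id']}", data=payload)

# GET /files/{bucket}/{key}: fetch the stored file's metadata
info = requests.get(f"{BASE}/files/docs/notes/hello.txt").json()
print(info["bytes"], info["mime_type"])
```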
@ -216,7 +216,7 @@ class Telemetry(Protocol):
|
|||
@webmethod(route="/telemetry/events", method="POST")
|
||||
async def log_event(self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400) -> None: ...
|
||||
|
||||
@webmethod(route="/telemetry/traces", method="GET")
|
||||
@webmethod(route="/telemetry/traces", method="POST")
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: Optional[List[QueryCondition]] = None,
|
||||
|
@ -231,7 +231,7 @@ class Telemetry(Protocol):
|
|||
@webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
|
||||
async def get_span(self, trace_id: str, span_id: str) -> Span: ...
|
||||
|
||||
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="GET")
|
||||
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST")
|
||||
async def get_span_tree(
|
||||
self,
|
||||
span_id: str,
|
||||
|
@ -239,7 +239,7 @@ class Telemetry(Protocol):
|
|||
max_depth: Optional[int] = None,
|
||||
) -> QuerySpanTreeResponse: ...
|
||||
|
||||
@webmethod(route="/telemetry/spans", method="GET")
|
||||
@webmethod(route="/telemetry/spans", method="POST")
|
||||
async def query_spans(
|
||||
self,
|
||||
attribute_filters: List[QueryCondition],
@ -5,12 +5,44 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
|
||||
|
||||
def _get_model_size(model_dir):
|
||||
return sum(f.stat().st_size for f in Path(model_dir).rglob("*") if f.is_file())
|
||||
|
||||
|
||||
def _run_model_list_downloaded_cmd() -> None:
|
||||
headers = ["Model", "Size", "Modified Time"]
|
||||
|
||||
rows = []
|
||||
for model in os.listdir(DEFAULT_CHECKPOINT_DIR):
|
||||
abs_path = os.path.join(DEFAULT_CHECKPOINT_DIR, model)
|
||||
space_usage = _get_model_size(abs_path)
|
||||
model_size = f"{space_usage / (1024**3):.2f} GB"
|
||||
modified_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(os.path.getmtime(abs_path)))
|
||||
rows.append(
|
||||
[
|
||||
model,
|
||||
model_size,
|
||||
modified_time,
|
||||
]
|
||||
)
|
||||
|
||||
print_table(
|
||||
rows,
|
||||
headers,
|
||||
separate_rows=True,
|
||||
)
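For reference, the size shown in the table is just the byte total of every file under the checkpoint directory, formatted in GiB. A standalone version of that calculation, with a hypothetical checkpoint path:

```python
from pathlib import Path

def get_model_size(model_dir) -> int:
    # Same idea as _get_model_size above: sum the sizes of all regular files.
    return sum(f.stat().st_size for f in Path(model_dir).rglob("*") if f.is_file())

# Hypothetical checkpoint directory; substitute a real path on your machine.
model_dir = Path.home() / ".llama" / "checkpoints" / "Llama3.1-8B-Instruct"
if model_dir.exists():
    print(f"{get_model_size(model_dir) / (1024 ** 3):.2f} GB")
```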
class ModelList(Subcommand):
|
||||
"""List available llama models"""
|
||||
|
||||
|
@ -31,10 +63,18 @@ class ModelList(Subcommand):
|
|||
action="store_true",
|
||||
help="Show all models (not just defaults)",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--downloaded",
|
||||
action="store_true",
|
||||
help="List the downloaded models",
|
||||
)
|
||||
|
||||
def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
|
||||
from .safety_models import prompt_guard_model_sku
|
||||
|
||||
if args.downloaded:
|
||||
return _run_model_list_downloaded_cmd()
|
||||
|
||||
headers = [
|
||||
"Model Descriptor(ID)",
|
||||
"Hugging Face Repo",
@ -15,6 +15,7 @@ TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
|||
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
|
||||
UV_SYSTEM_PYTHON=${UV_SYSTEM_PYTHON:-}
|
||||
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
|
||||
|
@ -73,11 +74,16 @@ run() {
|
|||
local env_name="$1"
|
||||
local pip_dependencies="$2"
|
||||
local special_pip_deps="$3"
|
||||
|
||||
if [ -n "$UV_SYSTEM_PYTHON" ]; then
|
||||
echo "Installing dependencies in system Python environment"
|
||||
else
|
||||
echo "Using virtual environment $env_name"
|
||||
uv venv "$env_name"
|
||||
# shellcheck source=/dev/null
|
||||
source "$env_name/bin/activate"
|
||||
fi
|
||||
|
||||
echo "Using virtual environment $env_name"
|
||||
uv venv "$env_name"
|
||||
# shellcheck source=/dev/null
|
||||
source "$env_name/bin/activate"
|
||||
if [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
# these packages are damaged in test-pypi, so install them first
|
||||
uv pip install fastapi libcst
@ -10,12 +10,12 @@ cleanup() {
|
|||
set +x
|
||||
echo "Cleaning up..."
|
||||
conda deactivate
|
||||
conda env remove --name $envname -y
|
||||
conda env remove --name "$envname" -y
|
||||
}
|
||||
|
||||
handle_int() {
|
||||
if [ -n $ENVNAME ]; then
|
||||
cleanup $ENVNAME
|
||||
if [ -n "$ENVNAME" ]; then
|
||||
cleanup "$ENVNAME"
|
||||
fi
|
||||
exit 1
|
||||
}
|
||||
|
@ -23,8 +23,8 @@ handle_int() {
|
|||
handle_exit() {
|
||||
if [ $? -ne 0 ]; then
|
||||
echo -e "\033[1;31mABORTING.\033[0m"
|
||||
if [ -n $ENVNAME ]; then
|
||||
cleanup $ENVNAME
|
||||
if [ -n "$ENVNAME" ]; then
|
||||
cleanup "$ENVNAME"
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
@ -33,10 +33,14 @@ setup_cleanup_handlers() {
|
|||
trap handle_int INT
|
||||
trap handle_exit EXIT
|
||||
|
||||
__conda_setup="$('conda' 'shell.bash' 'hook' 2>/dev/null)"
|
||||
eval "$__conda_setup"
|
||||
|
||||
conda deactivate
|
||||
if is_command_available conda; then
|
||||
__conda_setup="$('conda' 'shell.bash' 'hook' 2>/dev/null)"
|
||||
eval "$__conda_setup"
|
||||
conda deactivate
|
||||
else
|
||||
echo "conda is not available"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# check if a command is present
@ -411,48 +411,6 @@ class EvalRouter(Eval):
|
|||
job_id,
|
||||
)
|
||||
|
||||
async def DEPRECATED_run_eval(
|
||||
self,
|
||||
task_id: str,
|
||||
task_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
return await self.run_eval(benchmark_id=task_id, task_config=task_config)
|
||||
|
||||
async def DEPRECATED_evaluate_rows(
|
||||
self,
|
||||
task_id: str,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
task_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
return await self.evaluate_rows(
|
||||
benchmark_id=task_id,
|
||||
input_rows=input_rows,
|
||||
scoring_functions=scoring_functions,
|
||||
task_config=task_config,
|
||||
)
|
||||
|
||||
async def DEPRECATED_job_status(
|
||||
self,
|
||||
task_id: str,
|
||||
job_id: str,
|
||||
) -> Optional[JobStatus]:
|
||||
return await self.job_status(benchmark_id=task_id, job_id=job_id)
|
||||
|
||||
async def DEPRECATED_job_cancel(
|
||||
self,
|
||||
task_id: str,
|
||||
job_id: str,
|
||||
) -> None:
|
||||
return await self.job_cancel(benchmark_id=task_id, job_id=job_id)
|
||||
|
||||
async def DEPRECATED_job_result(
|
||||
self,
|
||||
task_id: str,
|
||||
job_id: str,
|
||||
) -> EvaluateResponse:
|
||||
return await self.job_result(benchmark_id=task_id, job_id=job_id)
class ToolRuntimeRouter(ToolRuntime):
|
||||
class RagToolImpl(RAGToolRuntime):
@ -468,35 +468,6 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
|||
)
|
||||
await self.register_object(benchmark)
|
||||
|
||||
async def DEPRECATED_list_eval_tasks(self) -> ListBenchmarksResponse:
|
||||
logger.warning("DEPRECATED: Use /eval/benchmarks instead")
|
||||
return await self.list_benchmarks()
|
||||
|
||||
async def DEPRECATED_get_eval_task(
|
||||
self,
|
||||
eval_task_id: str,
|
||||
) -> Optional[Benchmark]:
|
||||
logger.warning("DEPRECATED: Use /eval/benchmarks instead")
|
||||
return await self.get_benchmark(eval_task_id)
|
||||
|
||||
async def DEPRECATED_register_eval_task(
|
||||
self,
|
||||
eval_task_id: str,
|
||||
dataset_id: str,
|
||||
scoring_functions: List[str],
|
||||
provider_benchmark_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
logger.warning("DEPRECATED: Use /eval/benchmarks instead")
|
||||
return await self.register_benchmark(
|
||||
benchmark_id=eval_task_id,
|
||||
dataset_id=dataset_id,
|
||||
scoring_functions=scoring_functions,
|
||||
metadata=metadata,
|
||||
provider_benchmark_id=provider_benchmark_id,
|
||||
)
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
@ -481,6 +481,8 @@ def main():
|
|||
def extract_path_params(route: str) -> List[str]:
|
||||
segments = route.split("/")
|
||||
params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]
|
||||
# to handle path params like {param:path}
|
||||
params = [param.split(":")[0] for param in params]
|
||||
return params
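The extra `split(":")[0]` pass is what lets typed placeholders such as `{key:path}` resolve to their bare parameter names. A small self-check, assuming the function behaves exactly as written above:

```python
def extract_path_params(route: str):
    segments = route.split("/")
    params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]
    # Strip the type annotation from path params like {param:path}
    return [param.split(":")[0] for param in params]

assert extract_path_params("/files/{bucket}/{key:path}") == ["bucket", "key"]
assert extract_path_params("/telemetry/traces/{trace_id:path}/spans/{span_id:path}") == ["trace_id", "span_id"]
```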
@ -19,6 +19,7 @@ from llama_stack.apis.benchmarks import Benchmarks
|
|||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.models import Models
|
||||
|
@ -63,6 +64,7 @@ class LlamaStack(
|
|||
ToolGroups,
|
||||
ToolRuntime,
|
||||
RAGToolRuntime,
|
||||
Files,
|
||||
):
|
||||
pass
@ -34,6 +34,7 @@ def _run_with_pty_unix(command):
|
|||
original_sigint = signal.getsignal(signal.SIGINT)
|
||||
|
||||
ctrl_c_pressed = False
|
||||
process = None
|
||||
|
||||
def sigint_handler(signum, frame):
|
||||
nonlocal ctrl_c_pressed
|
||||
|
@ -98,7 +99,7 @@ def _run_with_pty_unix(command):
|
|||
signal.signal(signal.SIGINT, original_sigint)
|
||||
|
||||
os.close(master)
|
||||
if process.poll() is None:
|
||||
if process and process.poll() is None:
|
||||
process.terminate()
|
||||
process.wait()
@ -234,45 +234,3 @@ class MetaReferenceEvalImpl(
|
|||
raise ValueError(f"Job is not completed, Status: {status.value}")
|
||||
|
||||
return self.jobs[job_id]
|
||||
|
||||
async def DEPRECATED_run_eval(
|
||||
self,
|
||||
task_id: str,
|
||||
task_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
return await self.run_eval(benchmark_id=task_id, task_config=task_config)
|
||||
|
||||
async def DEPRECATED_evaluate_rows(
|
||||
self,
|
||||
task_id: str,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
task_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
return await self.evaluate_rows(
|
||||
benchmark_id=task_id,
|
||||
input_rows=input_rows,
|
||||
scoring_functions=scoring_functions,
|
||||
task_config=task_config,
|
||||
)
|
||||
|
||||
async def DEPRECATED_job_status(
|
||||
self,
|
||||
task_id: str,
|
||||
job_id: str,
|
||||
) -> Optional[JobStatus]:
|
||||
return await self.job_status(benchmark_id=task_id, job_id=job_id)
|
||||
|
||||
async def DEPRECATED_job_cancel(
|
||||
self,
|
||||
task_id: str,
|
||||
job_id: str,
|
||||
) -> None:
|
||||
return await self.job_cancel(benchmark_id=task_id, job_id=job_id)
|
||||
|
||||
async def DEPRECATED_job_result(
|
||||
self,
|
||||
task_id: str,
|
||||
job_id: str,
|
||||
) -> EvaluateResponse:
|
||||
return await self.job_result(benchmark_id=task_id, job_id=job_id)
@ -46,7 +46,7 @@ from llama_stack.providers.utils.inference.embedding_mixin import (
|
|||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
build_model_alias,
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
augment_content_with_response_format_prompt,
|
||||
|
@ -116,7 +116,7 @@ class MetaReferenceInferenceImpl(
|
|||
|
||||
self.model_registry_helper = ModelRegistryHelper(
|
||||
[
|
||||
build_model_alias(
|
||||
build_hf_repo_model_entry(
|
||||
llama_model.descriptor(),
|
||||
llama_model.core_model_id.value,
|
||||
)
@ -9,7 +9,6 @@ import os
|
|||
import uuid
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
|
@ -62,7 +61,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
def __init__(self, config: VLLMConfig):
|
||||
self.config = config
|
||||
self.engine = None
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
async def initialize(self):
|
||||
log.info("Initializing vLLM inference provider.")
|
||||
|
@ -177,7 +175,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
log.info("Sampling params: %s", sampling_params)
|
||||
request_id = _random_uuid()
|
||||
|
||||
prompt = await chat_completion_request_to_prompt(request, self.config.model, self.formatter)
|
||||
prompt = await chat_completion_request_to_prompt(request, self.config.model)
|
||||
vllm_sampling_params = self._sampling_params(request.sampling_params)
|
||||
results_generator = self.engine.generate(prompt, vllm_sampling_params, request_id)
|
||||
if stream:
|
||||
|
@ -201,11 +199,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
return process_chat_completion_response(response, self.formatter, request)
|
||||
return process_chat_completion_response(response, request)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
|
||||
) -> AsyncGenerator:
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
cur = []
|
||||
async for chunk in results_generator:
|
||||
|
@ -216,7 +216,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
output = chunk.outputs[-1]
|
||||
|
||||
new_tokens = output.token_ids[len(cur) :]
|
||||
text = self.formatter.tokenizer.decode(new_tokens)
|
||||
text = tokenizer.decode(new_tokens)
|
||||
cur.extend(new_tokens)
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=output.finish_reason,
|
||||
|
@ -227,7 +227,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
@ -40,8 +40,7 @@ class TorchtunePostTrainingImpl:
|
|||
self.datasets_api = datasets
|
||||
|
||||
# TODO: assume sync job, will need jobs API for async scheduling
|
||||
self.jobs_status = {}
|
||||
self.jobs_list = []
|
||||
self.jobs = {}
|
||||
self.checkpoints_dict = {}
|
||||
|
||||
async def supervised_fine_tune(
|
||||
|
@ -54,9 +53,8 @@ class TorchtunePostTrainingImpl:
|
|||
checkpoint_dir: Optional[str],
|
||||
algorithm_config: Optional[AlgorithmConfig],
|
||||
) -> PostTrainingJob:
|
||||
for job in self.jobs_list:
|
||||
if job_uuid == job.job_uuid:
|
||||
raise ValueError(f"Job {job_uuid} already exists")
|
||||
if job_uuid in self.jobs:
|
||||
raise ValueError(f"Job {job_uuid} already exists")
|
||||
|
||||
post_training_job = PostTrainingJob(job_uuid=job_uuid)
|
||||
|
||||
|
@ -65,8 +63,8 @@ class TorchtunePostTrainingImpl:
|
|||
status=JobStatus.scheduled,
|
||||
scheduled_at=datetime.now(),
|
||||
)
|
||||
self.jobs[job_uuid] = job_status_response
|
||||
|
||||
self.jobs_list.append(post_training_job)
|
||||
if isinstance(algorithm_config, LoraFinetuningConfig):
|
||||
try:
|
||||
recipe = LoraFinetuningSingleDevice(
|
||||
|
@ -100,8 +98,6 @@ class TorchtunePostTrainingImpl:
|
|||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
self.jobs_status[job_uuid] = job_status_response
|
||||
|
||||
return post_training_job
|
||||
|
||||
async def preference_optimize(
|
||||
|
@ -115,13 +111,11 @@ class TorchtunePostTrainingImpl:
|
|||
) -> PostTrainingJob: ...
|
||||
|
||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
|
||||
return ListPostTrainingJobsResponse(data=self.jobs_list)
|
||||
return ListPostTrainingJobsResponse(data=[PostTrainingJob(job_uuid=uuid_) for uuid_ in self.jobs])
|
||||
|
||||
@webmethod(route="/post-training/job/status")
|
||||
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
|
||||
if job_uuid in self.jobs_status:
|
||||
return self.jobs_status[job_uuid]
|
||||
return None
|
||||
return self.jobs.get(job_uuid, None)
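The refactor collapses the old `jobs_list` and `jobs_status` structures into a single dict keyed by job UUID, so listing, duplicate detection, and status lookup all read from one place. A reduced sketch of that pattern, independent of the provider class:

```python
from typing import Dict, Optional


class JobRegistry:
    """Minimal illustration of tracking jobs in one dict keyed by UUID."""

    def __init__(self) -> None:
        self.jobs: Dict[str, dict] = {}

    def add(self, job_uuid: str, status: dict) -> None:
        if job_uuid in self.jobs:
            raise ValueError(f"Job {job_uuid} already exists")
        self.jobs[job_uuid] = status

    def list_ids(self) -> list:
        return list(self.jobs)  # iterating the dict yields the job UUIDs

    def get(self, job_uuid: str) -> Optional[dict]:
        return self.jobs.get(job_uuid, None)
```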
@webmethod(route="/post-training/job/cancel")
|
||||
async def cancel_training_job(self, job_uuid: str) -> None:
@ -64,8 +64,6 @@ from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
|||
|
||||
|
||||
class LoraFinetuningSingleDevice:
|
||||
# This recipe only supports GPU training
|
||||
|
||||
# This recipe doesn't include several training efficiency setting within origin torchtune repo, including
|
||||
# - compile
|
||||
# - activation offloading
|
||||
|
@ -93,7 +91,7 @@ class LoraFinetuningSingleDevice:
|
|||
if not isinstance(algorithm_config, LoraFinetuningConfig):
|
||||
raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning")
|
||||
self.algorithm_config = algorithm_config
|
||||
self._device = torchtune_utils.get_device(device="cuda")
|
||||
self._device = torchtune_utils.get_device()
|
||||
self._dtype = training.get_dtype(training_config.dtype, device=self._device)
|
||||
self.model_id = model
|
||||
|
||||
|
@ -231,6 +229,13 @@ class LoraFinetuningSingleDevice:
|
|||
# Used to ignore labels for loss computation
|
||||
self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device)
|
||||
|
||||
def _log_memory_stats(self):
|
||||
# torchtune raises: "Logging memory stats is not supported on CPU devices"; do nothing
|
||||
if self._device.type == "cpu":
|
||||
return
|
||||
memory_stats = training.get_memory_stats(device=self._device)
|
||||
training.log_memory_stats(memory_stats)
|
||||
|
||||
async def _setup_model(
|
||||
self,
|
||||
enable_activation_checkpointing: bool,
|
||||
|
@ -293,8 +298,7 @@ class LoraFinetuningSingleDevice:
|
|||
# activation offloading
|
||||
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading)
|
||||
|
||||
memory_stats = training.get_memory_stats(device=self._device)
|
||||
training.log_memory_stats(memory_stats)
|
||||
self._log_memory_stats()
|
||||
|
||||
return model
|
||||
|
||||
|
@ -506,8 +510,7 @@ class LoraFinetuningSingleDevice:
|
|||
"tokens_per_second_per_gpu": num_tokens / time_per_step,
|
||||
}
|
||||
|
||||
memory_stats = training.get_memory_stats(device=self._device)
|
||||
log_dict.update(memory_stats)
|
||||
self._log_memory_stats()
|
||||
|
||||
if self._clip_grad_norm is not None:
|
||||
log_dict.update({"grad_norm": grad_norm})
@ -5,7 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams, ScoringFn
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
LLMAsJudgeScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
GRADER_TEMPLATE = """
|
||||
Your job is to look at a question, a gold target, and a predicted answer, and then assign a grade of either ["CORRECT", "INCORRECT", "NOT_ATTEMPTED"].
|
||||
|
@ -87,5 +91,6 @@ llm_as_judge_405b_simpleqa = ScoringFn(
|
|||
judge_model="meta-llama/Llama-3.1-405B-Instruct",
|
||||
prompt_template=GRADER_TEMPLATE,
|
||||
judge_score_regexes=[r"(A|B|C)"],
|
||||
aggregation_functions=[AggregationFunctionType.categorical_count.value],
|
||||
),
|
||||
)
@ -9,10 +9,11 @@ from typing import Any, Dict
|
|||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import RagToolRuntimeConfig
|
||||
from .memory import MemoryToolRuntimeImpl
|
||||
|
||||
|
||||
async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[str, Any]):
|
||||
from .memory import MemoryToolRuntimeImpl
|
||||
|
||||
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
|
||||
await impl.initialize()
|
||||
return impl
@ -4,9 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import sqlite3
|
||||
import struct
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
@ -52,14 +54,14 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
# Create the table to store chunk metadata.
|
||||
cur.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.metadata_table} (
|
||||
id INTEGER PRIMARY KEY,
|
||||
id TEXT PRIMARY KEY,
|
||||
chunk TEXT
|
||||
);
|
||||
""")
|
||||
# Create the virtual table for embeddings.
|
||||
cur.execute(f"""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
|
||||
USING vec0(embedding FLOAT[{self.dimension}]);
|
||||
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
|
||||
""")
|
||||
self.connection.commit()
|
||||
|
||||
|
@ -69,9 +71,9 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
|
||||
self.connection.commit()
|
||||
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray, batch_size: int = 500):
|
||||
"""
|
||||
Add new chunks along with their embeddings.
|
||||
Add new chunks along with their embeddings using batch inserts.
|
||||
For each chunk, we insert its JSON into the metadata table and then insert its
|
||||
embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
|
||||
If any insert fails, the transaction is rolled back to maintain consistency.
|
||||
|
@ -80,21 +82,35 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
try:
|
||||
# Start transaction
|
||||
cur.execute("BEGIN TRANSACTION")
|
||||
for chunk, emb in zip(chunks, embeddings, strict=False):
|
||||
# Serialize and insert the chunk metadata.
|
||||
chunk_json = chunk.model_dump_json()
|
||||
cur.execute(f"INSERT INTO {self.metadata_table} (chunk) VALUES (?)", (chunk_json,))
|
||||
row_id = cur.lastrowid
|
||||
# Ensure the embedding is a list of floats.
|
||||
emb_list = emb.tolist() if isinstance(emb, np.ndarray) else list(emb)
|
||||
emb_blob = serialize_vector(emb_list)
|
||||
cur.execute(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?)", (row_id, emb_blob))
|
||||
# Commit transaction if all inserts succeed
|
||||
for i in range(0, len(chunks), batch_size):
|
||||
batch_chunks = chunks[i : i + batch_size]
|
||||
batch_embeddings = embeddings[i : i + batch_size]
|
||||
# Prepare metadata inserts
|
||||
metadata_data = [
|
||||
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
|
||||
for chunk in batch_chunks
|
||||
]
|
||||
# Insert metadata (ON CONFLICT to avoid duplicates)
|
||||
cur.executemany(
|
||||
f"""
|
||||
INSERT INTO {self.metadata_table} (id, chunk)
|
||||
VALUES (?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
|
||||
""",
|
||||
metadata_data,
|
||||
)
|
||||
# Prepare embeddings inserts
|
||||
embedding_data = [
|
||||
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), serialize_vector(emb.tolist()))
|
||||
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
|
||||
]
|
||||
# Insert embeddings in batch
|
||||
cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
|
||||
self.connection.commit()
|
||||
|
||||
except sqlite3.Error as e:
|
||||
self.connection.rollback() # Rollback on failure
|
||||
print(f"Error inserting into {self.vector_table} - error: {e}") # Log error (Consider using logging module)
|
||||
logger.error(f"Error inserting into {self.vector_table}: {e}")
|
||||
|
||||
finally:
|
||||
cur.close() # Ensure cursor is closed
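The batched path above relies on two sqlite3 features: `executemany` for bulk inserts and `ON CONFLICT ... DO UPDATE` so re-inserting an existing chunk id becomes an upsert instead of an error. A self-contained sketch of the same pattern, outside the provider class:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE chunks (id TEXT PRIMARY KEY, chunk TEXT)")

rows = [("doc1:0", '{"content": "first"}'), ("doc1:1", '{"content": "second"}')]

# Bulk upsert: a duplicate id updates the stored chunk rather than raising.
cur.executemany(
    "INSERT INTO chunks (id, chunk) VALUES (?, ?) "
    "ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;",
    rows,
)
conn.commit()
print(cur.execute("SELECT COUNT(*) FROM chunks").fetchone()[0])  # -> 2
```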
@ -110,7 +126,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
query_sql = f"""
|
||||
SELECT m.id, m.chunk, v.distance
|
||||
FROM {self.vector_table} AS v
|
||||
JOIN {self.metadata_table} AS m ON m.id = v.rowid
|
||||
JOIN {self.metadata_table} AS m ON m.id = v.id
|
||||
WHERE v.embedding MATCH ? AND k = ?
|
||||
ORDER BY v.distance;
|
||||
"""
|
||||
|
@ -204,7 +220,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
if vector_db_id not in self.cache:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
|
||||
# The VectorDBWithIndex helper is expected to compute embeddings via the inference_api
|
||||
# and then call our index’s add_chunks.
|
||||
# and then call our index's add_chunks.
|
||||
await self.cache[vector_db_id].insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
|
@ -213,3 +229,9 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
if vector_db_id not in self.cache:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
return await self.cache[vector_db_id].query_chunks(query, params)
|
||||
|
||||
|
||||
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
|
||||
"""Generate a unique chunk ID using a hash of document ID and chunk text."""
|
||||
hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
|
||||
return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
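Because the id is an MD5 of the document id plus the chunk text, rendered as a UUID, a given chunk always maps to the same id; that determinism is what makes the `ON CONFLICT` upsert above safe. A quick check:

```python
import hashlib
import uuid


def generate_chunk_id(document_id: str, chunk_text: str) -> str:
    hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))


a = generate_chunk_id("doc-1", "hello world")
assert a == generate_chunk_id("doc-1", "hello world")   # deterministic
assert a != generate_chunk_id("doc-2", "hello world")   # scoped by document id
print(a)
```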
@ -215,4 +215,14 @@ def available_providers() -> List[ProviderSpec]:
|
|||
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="passthrough",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.passthrough",
|
||||
config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
|
||||
),
|
||||
),
|
||||
]
@ -8,8 +8,6 @@ import json
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

from botocore.client import BaseClient
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (

@ -27,12 +25,10 @@ from llama_stack.apis.inference import (
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    build_model_alias,
)
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAICompatCompletionChoice,

@ -47,29 +43,15 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
    interleaved_content_as_str,
)

MODEL_ALIASES = [
    build_model_alias(
        "meta.llama3-1-8b-instruct-v1:0",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias(
        "meta.llama3-1-70b-instruct-v1:0",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_model_alias(
        "meta.llama3-1-405b-instruct-v1:0",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
]
from .models import MODEL_ENTRIES


class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
    def __init__(self, config: BedrockConfig) -> None:
        ModelRegistryHelper.__init__(self, MODEL_ALIASES)
        ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
        self._config = config

        self._client = create_bedrock_client(config)
        self.formatter = ChatFormat(Tokenizer.get_instance())

    @property
    def client(self) -> BaseClient:

@ -134,7 +116,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
        )

        response = OpenAICompatCompletionResponse(choices=[choice])
        return process_chat_completion_response(response, self.formatter, request)
        return process_chat_completion_response(response, request)

    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        params = await self._get_params_for_chat_completion(request)

@ -152,7 +134,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
            yield OpenAICompatCompletionResponse(choices=[choice])

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
        async for chunk in process_chat_completion_stream_response(stream, request):
            yield chunk

    async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:

@ -166,7 +148,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
        if sampling_params.repetition_penalty > 0:
            options["repetition_penalty"] = sampling_params.repetition_penalty

        prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model), self.formatter)
        prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
        return {
            "modelId": bedrock_model,
            "body": json.dumps(
25 llama_stack/providers/remote/inference/bedrock/models.py Normal file

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    build_hf_repo_model_entry,
)

MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "meta.llama3-1-8b-instruct-v1:0",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta.llama3-1-70b-instruct-v1:0",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta.llama3-1-405b-instruct-v1:0",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
]
@ -7,8 +7,6 @@
from typing import AsyncGenerator, List, Optional, Union

from cerebras.cloud.sdk import AsyncCerebras
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (

@ -26,10 +24,9 @@ from llama_stack.apis.inference import (
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.models.llama.datatypes import CoreModelId, TopKSamplingStrategy
from llama_stack.models.llama.datatypes import TopKSamplingStrategy
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    build_model_alias,
)
from llama_stack.providers.utils.inference.openai_compat import (
    get_sampling_options,

@ -44,27 +41,16 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)

from .config import CerebrasImplConfig

model_aliases = [
    build_model_alias(
        "llama3.1-8b",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias(
        "llama-3.3-70b",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
]
from .models import model_entries


class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
    def __init__(self, config: CerebrasImplConfig) -> None:
        ModelRegistryHelper.__init__(
            self,
            model_aliases=model_aliases,
            model_entries=model_entries,
        )
        self.config = config
        self.formatter = ChatFormat(Tokenizer.get_instance())

        self.client = AsyncCerebras(
            base_url=self.config.base_url,

@ -107,14 +93,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):

        r = await self.client.completions.create(**params)

        return process_completion_response(r, self.formatter)
        return process_completion_response(r)

    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        params = await self._get_params(request)

        stream = await self.client.completions.create(**params)

        async for chunk in process_completion_stream_response(stream, self.formatter):
        async for chunk in process_completion_stream_response(stream):
            yield chunk

    async def chat_completion(

@ -154,14 +140,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):

        r = await self.client.completions.create(**params)

        return process_chat_completion_response(r, self.formatter, request)
        return process_chat_completion_response(r, request)

    async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
        params = await self._get_params(request)

        stream = await self.client.completions.create(**params)

        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
        async for chunk in process_chat_completion_stream_response(stream, request):
            yield chunk

    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:

@ -170,11 +156,9 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):

        prompt = ""
        if isinstance(request, ChatCompletionRequest):
            prompt = await chat_completion_request_to_prompt(
                request, self.get_llama_model(request.model), self.formatter
            )
            prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
        elif isinstance(request, CompletionRequest):
            prompt = await completion_request_to_prompt(request, self.formatter)
            prompt = await completion_request_to_prompt(request)
        else:
            raise ValueError(f"Unknown request type {type(request)}")
21 llama_stack/providers/remote/inference/cerebras/models.py Normal file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    build_hf_repo_model_entry,
)

model_entries = [
    build_hf_repo_model_entry(
        "llama3.1-8b",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama-3.3-70b",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
]
@ -6,8 +6,6 @@

from typing import AsyncGenerator, List, Optional

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI

from llama_stack.apis.common.content_types import InterleavedContent

@ -27,7 +25,7 @@ from llama_stack.apis.inference import (
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    build_model_alias,
    build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
    get_sampling_options,

@ -40,12 +38,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

from .config import DatabricksImplConfig

model_aliases = [
    build_model_alias(
model_entries = [
    build_hf_repo_model_entry(
        "databricks-meta-llama-3-1-70b-instruct",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_model_alias(
    build_hf_repo_model_entry(
        "databricks-meta-llama-3-1-405b-instruct",
        CoreModelId.llama3_1_405b_instruct.value,
    ),

@ -54,12 +52,8 @@ model_aliases = [

class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
    def __init__(self, config: DatabricksImplConfig) -> None:
        ModelRegistryHelper.__init__(
            self,
            model_aliases=model_aliases,
        )
        ModelRegistryHelper.__init__(self, model_entries=model_entries)
        self.config = config
        self.formatter = ChatFormat(Tokenizer.get_instance())

    async def initialize(self) -> None:
        return

@ -112,7 +106,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
    ) -> ChatCompletionResponse:
        params = self._get_params(request)
        r = client.completions.create(**params)
        return process_chat_completion_response(r, self.formatter, request)
        return process_chat_completion_response(r, request)

    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
        params = self._get_params(request)

@ -123,13 +117,13 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
            yield chunk

        stream = _to_async_generator()
        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
        async for chunk in process_chat_completion_stream_response(stream, request):
            yield chunk

    def _get_params(self, request: ChatCompletionRequest) -> dict:
        return {
            "model": request.model,
            "prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model), self.formatter),
            "prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model)),
            "stream": request.stream,
            **get_sampling_options(request.sampling_params),
        }
@ -7,8 +7,6 @@
from typing import AsyncGenerator, List, Optional, Union

from fireworks.client import Fireworks
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (

@ -29,10 +27,8 @@ from llama_stack.apis.inference import (
    ToolPromptFormat,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    build_model_alias,
)
from llama_stack.providers.utils.inference.openai_compat import (
    convert_message_to_openai_dict,

@ -51,56 +47,13 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)

from .config import FireworksImplConfig

MODEL_ALIASES = [
    build_model_alias(
        "accounts/fireworks/models/llama-v3p1-8b-instruct",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias(
        "accounts/fireworks/models/llama-v3p1-70b-instruct",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_model_alias(
        "accounts/fireworks/models/llama-v3p1-405b-instruct",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
    build_model_alias(
        "accounts/fireworks/models/llama-v3p2-1b-instruct",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
    build_model_alias(
        "accounts/fireworks/models/llama-v3p2-3b-instruct",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_model_alias(
        "accounts/fireworks/models/llama-v3p2-11b-vision-instruct",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_model_alias(
        "accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    build_model_alias(
        "accounts/fireworks/models/llama-v3p3-70b-instruct",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    build_model_alias(
        "accounts/fireworks/models/llama-guard-3-8b",
        CoreModelId.llama_guard_3_8b.value,
    ),
    build_model_alias(
        "accounts/fireworks/models/llama-guard-3-11b-vision",
        CoreModelId.llama_guard_3_11b_vision.value,
    ),
]
from .models import MODEL_ENTRIES


class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
    def __init__(self, config: FireworksImplConfig) -> None:
        ModelRegistryHelper.__init__(self, MODEL_ALIASES)
        ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
        self.config = config
        self.formatter = ChatFormat(Tokenizer.get_instance())

    async def initialize(self) -> None:
        pass

@ -149,7 +102,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
        params = await self._get_params(request)
        r = await self._get_client().completion.acreate(**params)
        return process_completion_response(r, self.formatter)
        return process_completion_response(r)

    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        params = await self._get_params(request)

@ -161,7 +114,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
            yield chunk

        stream = _to_async_generator()
        async for chunk in process_completion_stream_response(stream, self.formatter):
        async for chunk in process_completion_stream_response(stream):
            yield chunk

    def _build_options(

@ -230,7 +183,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
            r = await self._get_client().chat.completions.acreate(**params)
        else:
            r = await self._get_client().completion.acreate(**params)
        return process_chat_completion_response(r, self.formatter, request)
        return process_chat_completion_response(r, request)

    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        params = await self._get_params(request)

@ -244,7 +197,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
            yield chunk

        stream = _to_async_generator()
        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
        async for chunk in process_chat_completion_stream_response(stream, request):
            yield chunk

    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:

@ -258,11 +211,11 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
                ]
            else:
                input_dict["prompt"] = await chat_completion_request_to_prompt(
                    request, self.get_llama_model(request.model), self.formatter
                    request, self.get_llama_model(request.model)
                )
        else:
            assert not media_present, "Fireworks does not support media for Completion requests"
            input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
            input_dict["prompt"] = await completion_request_to_prompt(request)

        # Fireworks always prepends with BOS
        if "prompt" in input_dict:

@ -284,8 +237,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
        model = await self.model_store.get_model(model_id)

        kwargs = {}
        if model.metadata.get("embedding_dimensions"):
            kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
        if model.metadata.get("embedding_dimension"):
            kwargs["dimensions"] = model.metadata.get("embedding_dimension")
        assert all(not content_has_media(content) for content in contents), (
            "Fireworks does not support media for embeddings"
        )
63 llama_stack/providers/remote/inference/fireworks/models.py Normal file

@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.models.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    ProviderModelEntry,
    build_hf_repo_model_entry,
)

MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "accounts/fireworks/models/llama-v3p1-8b-instruct",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "accounts/fireworks/models/llama-v3p1-70b-instruct",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "accounts/fireworks/models/llama-v3p1-405b-instruct",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "accounts/fireworks/models/llama-v3p2-1b-instruct",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "accounts/fireworks/models/llama-v3p2-3b-instruct",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "accounts/fireworks/models/llama-v3p2-11b-vision-instruct",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "accounts/fireworks/models/llama-v3p3-70b-instruct",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "accounts/fireworks/models/llama-guard-3-8b",
        CoreModelId.llama_guard_3_8b.value,
    ),
    build_hf_repo_model_entry(
        "accounts/fireworks/models/llama-guard-3-11b-vision",
        CoreModelId.llama_guard_3_11b_vision.value,
    ),
    ProviderModelEntry(
        provider_model_id="nomic-ai/nomic-embed-text-v1.5",
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": 768,
            "context_length": 8192,
        },
    ),
]
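The new `models.py` modules in this commit move each provider's model mapping out of its adapter and into a standalone entry list. A minimal sketch of how such a list might be consumed is below; `ModelRegistryHelper` and `get_llama_model` both appear in this diff, but the exact value returned for the lookup is an assumption for illustration only.

```python
# Hedged sketch: wiring a MODEL_ENTRIES list into the registry helper, as the
# adapters in this diff do, then resolving a provider model id.
from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

helper = ModelRegistryHelper(MODEL_ENTRIES)

# Adapters call get_llama_model(<provider model id>) to recover the canonical
# Llama model identifier before building a prompt (return value illustrative).
print(helper.get_llama_model("accounts/fireworks/models/llama-v3p1-8b-instruct"))
```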
@ -31,8 +31,8 @@ from llama_stack.models.llama.sku_list import CoreModelId
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    build_model_alias,
    build_model_alias_with_just_provider_model_id,
    build_hf_repo_model_entry,
    build_model_entry,
)

from .groq_utils import (

@ -41,20 +41,20 @@ from .groq_utils import (
    convert_chat_completion_response_stream,
)

_MODEL_ALIASES = [
    build_model_alias(
_MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "llama3-8b-8192",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias_with_just_provider_model_id(
    build_model_entry(
        "llama-3.1-8b-instant",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias(
    build_hf_repo_model_entry(
        "llama3-70b-8192",
        CoreModelId.llama3_70b_instruct.value,
    ),
    build_model_alias(
    build_hf_repo_model_entry(
        "llama-3.3-70b-versatile",
        CoreModelId.llama3_3_70b_instruct.value,
    ),

@ -62,7 +62,7 @@ _MODEL_ALIASES = [
    # Preview models aren't recommended for production use, but we include this one
    # to pass the test fixture
    # TODO(aidand): Replace this with a stable model once Groq supports it
    build_model_alias(
    build_hf_repo_model_entry(
        "llama-3.2-3b-preview",
        CoreModelId.llama3_2_3b_instruct.value,
    ),

@ -73,7 +73,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
    _config: GroqConfig

    def __init__(self, config: GroqConfig):
        ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
        ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
        self._config = config

    def completion(
51 llama_stack/providers/remote/inference/nvidia/models.py Normal file

@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    build_hf_repo_model_entry,
)

_MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "meta/llama3-8b-instruct",
        CoreModelId.llama3_8b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama3-70b-instruct",
        CoreModelId.llama3_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.1-8b-instruct",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.1-70b-instruct",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.1-405b-instruct",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.2-1b-instruct",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.2-3b-instruct",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.2-11b-vision-instruct",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.2-90b-vision-instruct",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    # TODO(mf): how do we handle Nemotron models?
    # "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
]
@ -26,19 +26,14 @@ from llama_stack.apis.inference import (
    ToolChoice,
    ToolConfig,
)
from llama_stack.models.llama.datatypes import (
    CoreModelId,
    SamplingParams,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    build_model_alias,
)
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media

from . import NVIDIAConfig
from .models import _MODEL_ENTRIES
from .openai_utils import (
    convert_chat_completion_request,
    convert_completion_request,

@ -51,52 +46,11 @@ from .utils import _is_nvidia_hosted, check_health

logger = logging.getLogger(__name__)

_MODEL_ALIASES = [
    build_model_alias(
        "meta/llama3-8b-instruct",
        CoreModelId.llama3_8b_instruct.value,
    ),
    build_model_alias(
        "meta/llama3-70b-instruct",
        CoreModelId.llama3_70b_instruct.value,
    ),
    build_model_alias(
        "meta/llama-3.1-8b-instruct",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias(
        "meta/llama-3.1-70b-instruct",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_model_alias(
        "meta/llama-3.1-405b-instruct",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
    build_model_alias(
        "meta/llama-3.2-1b-instruct",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
    build_model_alias(
        "meta/llama-3.2-3b-instruct",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_model_alias(
        "meta/llama-3.2-11b-vision-instruct",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_model_alias(
        "meta/llama-3.2-90b-vision-instruct",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    # TODO(mf): how do we handle Nemotron models?
    # "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
]


class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
    def __init__(self, config: NVIDIAConfig) -> None:
        # TODO(mf): filter by available models
        ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
        ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)

        logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
103 llama_stack/providers/remote/inference/ollama/models.py Normal file

@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.models.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    ProviderModelEntry,
    build_hf_repo_model_entry,
    build_model_entry,
)

model_entries = [
    build_hf_repo_model_entry(
        "llama3.1:8b-instruct-fp16",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_entry(
        "llama3.1:8b",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama3.1:70b-instruct-fp16",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_model_entry(
        "llama3.1:70b",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama3.1:405b-instruct-fp16",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
    build_model_entry(
        "llama3.1:405b",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama3.2:1b-instruct-fp16",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
    build_model_entry(
        "llama3.2:1b",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama3.2:3b-instruct-fp16",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_model_entry(
        "llama3.2:3b",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama3.2-vision:11b-instruct-fp16",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_model_entry(
        "llama3.2-vision:latest",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama3.2-vision:90b-instruct-fp16",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    build_model_entry(
        "llama3.2-vision:90b",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama3.3:70b",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    # The Llama Guard models don't have their full fp16 versions
    # so we are going to alias their default version to the canonical SKU
    build_hf_repo_model_entry(
        "llama-guard3:8b",
        CoreModelId.llama_guard_3_8b.value,
    ),
    build_hf_repo_model_entry(
        "llama-guard3:1b",
        CoreModelId.llama_guard_3_1b.value,
    ),
    ProviderModelEntry(
        provider_model_id="all-minilm:latest",
        aliases=["all-minilm"],
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": 384,
            "context_length": 512,
        },
    ),
    ProviderModelEntry(
        provider_model_id="nomic-embed-text",
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": 768,
            "context_length": 8192,
        },
    ),
]
@ -8,8 +8,6 @@ import logging
from typing import AsyncGenerator, List, Optional, Union

import httpx
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from ollama import AsyncClient

from llama_stack.apis.common.content_types import (

@ -33,12 +31,9 @@ from llama_stack.apis.inference import (
    ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    build_model_alias,
    build_model_alias_with_just_provider_model_id,
)
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAICompatCompletionChoice,

@ -58,87 +53,15 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
    request_has_media,
)

log = logging.getLogger(__name__)
from .models import model_entries

model_aliases = [
    build_model_alias(
        "llama3.1:8b-instruct-fp16",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias_with_just_provider_model_id(
        "llama3.1:8b",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias(
        "llama3.1:70b-instruct-fp16",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_model_alias_with_just_provider_model_id(
        "llama3.1:70b",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_model_alias(
        "llama3.1:405b-instruct-fp16",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
    build_model_alias_with_just_provider_model_id(
        "llama3.1:405b",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
    build_model_alias(
        "llama3.2:1b-instruct-fp16",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
    build_model_alias_with_just_provider_model_id(
        "llama3.2:1b",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
    build_model_alias(
        "llama3.2:3b-instruct-fp16",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_model_alias_with_just_provider_model_id(
        "llama3.2:3b",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_model_alias(
        "llama3.2-vision:11b-instruct-fp16",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_model_alias_with_just_provider_model_id(
        "llama3.2-vision:latest",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_model_alias(
        "llama3.2-vision:90b-instruct-fp16",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    build_model_alias_with_just_provider_model_id(
        "llama3.2-vision:90b",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    build_model_alias(
        "llama3.3:70b",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    # The Llama Guard models don't have their full fp16 versions
    # so we are going to alias their default version to the canonical SKU
    build_model_alias(
        "llama-guard3:8b",
        CoreModelId.llama_guard_3_8b.value,
    ),
    build_model_alias(
        "llama-guard3:1b",
        CoreModelId.llama_guard_3_1b.value,
    ),
]
log = logging.getLogger(__name__)


class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
    def __init__(self, url: str) -> None:
        self.register_helper = ModelRegistryHelper(model_aliases)
        self.register_helper = ModelRegistryHelper(model_entries)
        self.url = url
        self.formatter = ChatFormat(Tokenizer.get_instance())

    @property
    def client(self) -> AsyncClient:

@ -197,7 +120,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
        )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_completion_stream_response(stream, self.formatter):
        async for chunk in process_completion_stream_response(stream):
            yield chunk

    async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:

@ -212,7 +135,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
            choices=[choice],
        )

        return process_completion_response(response, self.formatter)
        return process_completion_response(response)

    async def chat_completion(
        self,

@ -262,11 +185,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                input_dict["prompt"] = await chat_completion_request_to_prompt(
                    request,
                    self.register_helper.get_llama_model(request.model),
                    self.formatter,
                )
        else:
            assert not media_present, "Ollama does not support media for Completion requests"
            input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
            input_dict["prompt"] = await completion_request_to_prompt(request)
            input_dict["raw"] = True

        if fmt := request.response_format:

@ -304,7 +226,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
        response = OpenAICompatCompletionResponse(
            choices=[choice],
        )
        return process_chat_completion_response(response, self.formatter, request)
        return process_chat_completion_response(response, request)

    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        params = await self._get_params(request)

@ -330,7 +252,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
        )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
        async for chunk in process_chat_completion_stream_response(stream, request):
            yield chunk

    async def embeddings(

@ -352,22 +274,17 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
        return EmbeddingsResponse(embeddings=embeddings)

    async def register_model(self, model: Model) -> Model:
        async def check_model_availability(model_id: str):
            response = await self.client.ps()
            available_models = [m["model"] for m in response["models"]]
            if model_id not in available_models:
                raise ValueError(
                    f"Model '{model_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
                )

        if model.model_type == ModelType.embedding:
            await check_model_availability(model.provider_resource_id)
            return model
            response = await self.client.list()
        else:
            response = await self.client.ps()
        available_models = [m["model"] for m in response["models"]]
        if model.provider_resource_id not in available_models:
            raise ValueError(
                f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
            )

        model = await self.register_helper.register_model(model)
        await check_model_availability(model.provider_resource_id)

        return model
        return await self.register_helper.register_model(model)


async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel

from .config import PassthroughImplConfig


class PassthroughProviderDataValidator(BaseModel):
    url: str
    api_key: str


async def get_adapter_impl(config: PassthroughImplConfig, _deps):
    from .passthrough import PassthroughInferenceAdapter

    assert isinstance(config, PassthroughImplConfig), f"Unexpected config type: {type(config)}"
    impl = PassthroughInferenceAdapter(config)
    await impl.initialize()
    return impl
31 llama_stack/providers/remote/inference/passthrough/config.py Normal file

@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, SecretStr

from llama_stack.schema_utils import json_schema_type


@json_schema_type
class PassthroughImplConfig(BaseModel):
    url: str = Field(
        default=None,
        description="The URL for the passthrough endpoint",
    )

    api_key: Optional[SecretStr] = Field(
        default=None,
        description="API Key for the passthrough endpoint",
    )

    @classmethod
    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
        return {
            "url": "${env.PASSTHROUGH_URL}",
            "api_key": "${env.PASSTHROUGH_API_KEY}",
        }
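For context on the new config class above: the two fields can be filled from environment substitutions (as `sample_run_config` shows) or constructed directly. A small sketch with placeholder values, not taken from the diff:

```python
# Sketch only: constructing the passthrough config directly; the URL and key are placeholders.
from pydantic import SecretStr

from llama_stack.providers.remote.inference.passthrough.config import PassthroughImplConfig

config = PassthroughImplConfig(
    url="http://localhost:8321",           # placeholder endpoint
    api_key=SecretStr("example-api-key"),  # placeholder key
)

# Matches the class above: {"url": "${env.PASSTHROUGH_URL}", "api_key": "${env.PASSTHROUGH_API_KEY}"}
print(PassthroughImplConfig.sample_run_config())
```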
@ -0,0 +1,148 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AsyncGenerator, List, Optional

from llama_stack_client import LlamaStackClient

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
    EmbeddingsResponse,
    Inference,
    LogProbConfig,
    Message,
    ResponseFormat,
    SamplingParams,
    ToolChoice,
    ToolConfig,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

from .config import PassthroughImplConfig


class PassthroughInferenceAdapter(Inference):
    def __init__(self, config: PassthroughImplConfig) -> None:
        ModelRegistryHelper.__init__(self, [])
        self.config = config

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def unregister_model(self, model_id: str) -> None:
        pass

    async def register_model(self, model: Model) -> Model:
        return model

    def _get_client(self) -> LlamaStackClient:
        passthrough_url = None
        passthrough_api_key = None
        provider_data = None

        if self.config.url is not None:
            passthrough_url = self.config.url
        else:
            provider_data = self.get_request_provider_data()
            if provider_data is None or not provider_data.passthrough_url:
                raise ValueError(
                    'Pass url of the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_url": <your passthrough url>}'
                )
            passthrough_url = provider_data.passthrough_url

        if self.config.api_key is not None:
            passthrough_api_key = self.config.api_key.get_secret_value()
        else:
            provider_data = self.get_request_provider_data()
            if provider_data is None or not provider_data.passthrough_api_key:
                raise ValueError(
                    'Pass API Key for the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_api_key": <your api key>}'
                )
            passthrough_api_key = provider_data.passthrough_api_key

        return LlamaStackClient(
            base_url=passthrough_url,
            api_key=passthrough_api_key,
            provider_data=provider_data,
        )

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        client = self._get_client()
        model = await self.model_store.get_model(model_id)

        params = {
            "model_id": model.provider_resource_id,
            "content": content,
            "sampling_params": sampling_params,
            "response_format": response_format,
            "stream": stream,
            "logprobs": logprobs,
        }

        params = {key: value for key, value in params.items() if value is not None}

        # only pass through the not None params
        return client.inference.completion(**params)

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> AsyncGenerator:
        client = self._get_client()
        model = await self.model_store.get_model(model_id)

        params = {
            "model_id": model.provider_resource_id,
            "messages": messages,
            "sampling_params": sampling_params,
            "tools": tools,
            "tool_choice": tool_choice,
            "tool_prompt_format": tool_prompt_format,
            "response_format": response_format,
            "stream": stream,
            "logprobs": logprobs,
        }

        params = {key: value for key, value in params.items() if value is not None}

        # only pass through the not None params
        return client.inference.chat_completion(**params)

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse:
        client = self._get_client()
        model = await self.model_store.get_model(model_id)

        return client.inference.embeddings(
            model_id=model.provider_resource_id,
            contents=contents,
        )
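The adapter above can also receive the endpoint and key per request via the `X-LlamaStack-Provider-Data` header (see the error messages in `_get_client`). A hedged sketch of that header payload follows; the header name and JSON keys come from the code above, while how the header is attached to a request depends on the client in use and is assumed here:

```python
# Sketch: building the provider-data header the passthrough adapter looks for.
import json

provider_data_header = {
    "X-LlamaStack-Provider-Data": json.dumps(
        {
            "passthrough_url": "http://localhost:8321",  # placeholder URL
            "passthrough_api_key": "example-api-key",    # placeholder key
        }
    )
}
```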
@ -5,8 +5,6 @@
# the root directory of this source tree.
from typing import AsyncGenerator

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI

from llama_stack.apis.inference import *  # noqa: F403

@ -45,7 +43,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
    def __init__(self, config: RunpodImplConfig) -> None:
        ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
        self.config = config
        self.formatter = ChatFormat(Tokenizer.get_instance())

    async def initialize(self) -> None:
        return

@ -56,7 +53,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
    async def completion(
        self,
        model: str,
        content: InterleavedTextMedia,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,

@ -97,7 +94,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
    ) -> ChatCompletionResponse:
        params = self._get_params(request)
        r = client.completions.create(**params)
        return process_chat_completion_response(r, self.formatter, request)
        return process_chat_completion_response(r, request)

    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
        params = self._get_params(request)

@ -108,13 +105,13 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
            yield chunk

        stream = _to_async_generator()
        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
        async for chunk in process_chat_completion_stream_response(stream, request):
            yield chunk

    def _get_params(self, request: ChatCompletionRequest) -> dict:
        return {
            "model": self.map_to_provider_model(request.model),
            "prompt": chat_completion_request_to_prompt(request, self.formatter),
            "prompt": chat_completion_request_to_prompt(request),
            "stream": request.stream,
            **get_sampling_options(request.sampling_params),
        }

@ -122,6 +119,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
    async def embeddings(
        self,
        model: str,
        contents: List[InterleavedTextMedia],
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
@ -7,7 +7,6 @@
from pydantic import BaseModel

from .config import SambaNovaImplConfig
from .sambanova import SambaNovaInferenceAdapter


class SambaNovaProviderDataValidator(BaseModel):

@ -15,6 +14,8 @@ class SambaNovaProviderDataValidator(BaseModel):


async def get_adapter_impl(config: SambaNovaImplConfig, _deps):
    from .sambanova import SambaNovaInferenceAdapter

    assert isinstance(config, SambaNovaImplConfig), f"Unexpected config type: {type(config)}"
    impl = SambaNovaInferenceAdapter(config)
    await impl.initialize()
49 llama_stack/providers/remote/inference/sambanova/models.py Normal file

@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    build_hf_repo_model_entry,
)

MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "Meta-Llama-3.1-8B-Instruct",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "Meta-Llama-3.1-70B-Instruct",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "Meta-Llama-3.1-405B-Instruct",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "Meta-Llama-3.2-1B-Instruct",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "Meta-Llama-3.2-3B-Instruct",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "Meta-Llama-3.3-70B-Instruct",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "Llama-3.2-11B-Vision-Instruct",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "Llama-3.2-90B-Vision-Instruct",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "Meta-Llama-Guard-3-8B",
        CoreModelId.llama_guard_3_8b.value,
    ),
]
@ -7,8 +7,6 @@
import json
from typing import AsyncGenerator

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI

from llama_stack.apis.common.content_types import (

@ -18,14 +16,12 @@ from llama_stack.apis.common.content_types import (
)
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.models.llama.datatypes import (
    CoreModelId,
    GreedySamplingStrategy,
    TopKSamplingStrategy,
    TopPSamplingStrategy,
)
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    build_model_alias,
)
from llama_stack.providers.utils.inference.openai_compat import (
    process_chat_completion_stream_response,

@ -35,56 +31,13 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)

from .config import SambaNovaImplConfig

MODEL_ALIASES = [
    build_model_alias(
        "Meta-Llama-3.1-8B-Instruct",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias(
        "Meta-Llama-3.1-70B-Instruct",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_model_alias(
        "Meta-Llama-3.1-405B-Instruct",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
    build_model_alias(
        "Meta-Llama-3.2-1B-Instruct",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
    build_model_alias(
        "Meta-Llama-3.2-3B-Instruct",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_model_alias(
        "Meta-Llama-3.3-70B-Instruct",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    build_model_alias(
        "Llama-3.2-11B-Vision-Instruct",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_model_alias(
        "Llama-3.2-90B-Vision-Instruct",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    build_model_alias(
        "Meta-Llama-Guard-3-8B",
        CoreModelId.llama_guard_3_8b.value,
    ),
]
from .models import MODEL_ENTRIES


class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
    def __init__(self, config: SambaNovaImplConfig) -> None:
        ModelRegistryHelper.__init__(
            self,
            model_aliases=MODEL_ALIASES,
        )

        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
        self.config = config
        self.formatter = ChatFormat(Tokenizer.get_instance())

    async def initialize(self) -> None:
        return

@ -160,7 +113,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
            yield chunk

        stream = _to_async_generator()
        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
        async for chunk in process_chat_completion_stream_response(stream, request):
            yield chunk

    async def embeddings(
@ -7,13 +7,14 @@
from typing import Union

from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter


async def get_adapter_impl(
    config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig],
    _deps,
):
    from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter

    if isinstance(config, TGIImplConfig):
        impl = TGIAdapter()
    elif isinstance(config, InferenceAPIImplConfig):
@ -9,8 +9,6 @@ import logging
from typing import AsyncGenerator, List, Optional

from huggingface_hub import AsyncInferenceClient, HfApi
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (

@ -34,7 +32,7 @@ from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    build_model_alias,
    build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAICompatCompletionChoice,

@ -55,9 +53,9 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
log = logging.getLogger(__name__)


def build_model_aliases():
def build_hf_repo_model_entries():
    return [
        build_model_alias(
        build_hf_repo_model_entry(
            model.huggingface_repo,
            model.descriptor(),
        )

@ -72,8 +70,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
    model_id: str

    def __init__(self) -> None:
        self.formatter = ChatFormat(Tokenizer.get_instance())
        self.register_helper = ModelRegistryHelper(build_model_aliases())
        self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
        self.huggingface_repo_to_llama_model_id = {
            model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
        }

@ -149,7 +146,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
        return options

    async def _get_params_for_completion(self, request: CompletionRequest) -> dict:
        prompt, input_tokens = await completion_request_to_prompt_model_input_info(request, self.formatter)
        prompt, input_tokens = await completion_request_to_prompt_model_input_info(request)

        return dict(
            prompt=prompt,

@ -177,7 +174,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
        )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_completion_stream_response(stream, self.formatter):
        async for chunk in process_completion_stream_response(stream):
            yield chunk

    async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:

@ -193,7 +190,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
            choices=[choice],
        )

        return process_completion_response(response, self.formatter)
        return process_completion_response(response)

    async def chat_completion(
        self,

@ -236,7 +233,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
        response = OpenAICompatCompletionResponse(
            choices=[choice],
        )
        return process_chat_completion_response(response, self.formatter, request)
        return process_chat_completion_response(response, request)

    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        params = await self._get_params(request)

@ -252,12 +249,12 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
        )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
        async for chunk in process_chat_completion_stream_response(stream, request):
            yield chunk

    async def _get_params(self, request: ChatCompletionRequest) -> dict:
        prompt, input_tokens = await chat_completion_request_to_model_input_info(
            request, self.register_helper.get_llama_model(request.model), self.formatter
            request, self.register_helper.get_llama_model(request.model)
        )
        return dict(
            prompt=prompt,
67 llama_stack/providers/remote/inference/together/models.py Normal file

@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.models.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    ProviderModelEntry,
    build_hf_repo_model_entry,
)

MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta-llama/Llama-3.2-3B-Instruct-Turbo",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta-llama/Llama-3.3-70B-Instruct-Turbo",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta-llama/Meta-Llama-Guard-3-8B",
        CoreModelId.llama_guard_3_8b.value,
    ),
    build_hf_repo_model_entry(
        "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
        CoreModelId.llama_guard_3_11b_vision.value,
    ),
    ProviderModelEntry(
        provider_model_id="togethercomputer/m2-bert-80M-8k-retrieval",
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": 768,
            "context_length": 8192,
        },
    ),
    ProviderModelEntry(
        provider_model_id="togethercomputer/m2-bert-80M-32k-retrieval",
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": 768,
            "context_length": 32768,
        },
    ),
]
@ -6,8 +6,6 @@
|
|||
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from together import Together
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
|
@ -28,10 +26,8 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.models.llama.datatypes import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
build_model_alias,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
|
@ -50,52 +46,13 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
)
|
||||
|
||||
from .config import TogetherImplConfig
|
||||
|
||||
MODEL_ALIASES = [
|
||||
build_model_alias(
|
||||
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"meta-llama/Llama-3.2-3B-Instruct-Turbo",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
|
||||
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
|
||||
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"meta-llama/Llama-3.3-70B-Instruct-Turbo",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"meta-llama/Meta-Llama-Guard-3-8B",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"meta-llama/Llama-Guard-3-11B-Vision-Turbo",
|
||||
CoreModelId.llama_guard_3_11b_vision.value,
|
||||
),
|
||||
]
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
def __init__(self, config: TogetherImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ALIASES)
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
self.config = config
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
@ -142,7 +99,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = self._get_client().completions.create(**params)
|
||||
return process_completion_response(r, self.formatter)
|
||||
return process_completion_response(r)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
@ -154,7 +111,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_completion_stream_response(stream, self.formatter):
|
||||
async for chunk in process_completion_stream_response(stream):
|
||||
yield chunk
|
||||
|
||||
def _build_options(
|
||||
|
@ -220,7 +177,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
r = self._get_client().chat.completions.create(**params)
|
||||
else:
|
||||
r = self._get_client().completions.create(**params)
|
||||
return process_chat_completion_response(r, self.formatter, request)
|
||||
return process_chat_completion_response(r, request)
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
@ -235,7 +192,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
|
||||
|
@ -246,11 +203,11 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
|
||||
else:
|
||||
input_dict["prompt"] = await chat_completion_request_to_prompt(
|
||||
request, self.get_llama_model(request.model), self.formatter
|
||||
request, self.get_llama_model(request.model)
|
||||
)
|
||||
else:
|
||||
assert not media_present, "Together does not support media for Completion requests"
|
||||
input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
|
||||
input_dict["prompt"] = await completion_request_to_prompt(request)
|
||||
|
||||
return {
|
||||
"model": request.model,
|
||||
|
|
|
@ -8,8 +8,6 @@ import logging
|
|||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from llama_models.datatypes import StopReason, ToolCall
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
|
||||
|
@ -40,7 +38,7 @@ from llama_stack.models.llama.sku_list import all_registered_models
|
|||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
build_model_alias,
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAICompatCompletionResponse,
|
||||
|
@ -64,9 +62,9 @@ from .config import VLLMInferenceAdapterConfig
|
|||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_model_aliases():
|
||||
def build_hf_repo_model_entries():
|
||||
return [
|
||||
build_model_alias(
|
||||
build_hf_repo_model_entry(
|
||||
model.huggingface_repo,
|
||||
model.descriptor(),
|
||||
)
|
||||
|
@ -150,19 +148,36 @@ async def _process_vllm_chat_completion_stream_response(
|
|||
async for chunk in stream:
|
||||
choice = chunk.choices[0]
|
||||
if choice.finish_reason:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=ToolCallDelta(
|
||||
tool_call=ToolCall(
|
||||
call_id=tool_call_buf.call_id,
|
||||
tool_name=tool_call_buf.tool_name,
|
||||
arguments=json.loads(tool_call_buf.arguments),
|
||||
args_str = tool_call_buf.arguments
|
||||
args = None
|
||||
try:
|
||||
args = {} if not args_str else json.loads(args_str)
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
|
||||
if args is not None:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=ToolCallDelta(
|
||||
tool_call=ToolCall(
|
||||
call_id=tool_call_buf.call_id,
|
||||
tool_name=tool_call_buf.tool_name,
|
||||
arguments=args,
|
||||
),
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
tool_call=str(tool_call_buf),
|
||||
parse_status=ToolCallParseStatus.failed,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
|
@ -189,9 +204,8 @@ async def _process_vllm_chat_completion_stream_response(
|
|||
|
||||
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
|
||||
self.register_helper = ModelRegistryHelper(build_model_aliases())
|
||||
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
|
||||
self.config = config
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
self.client = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
@ -286,14 +300,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
if len(request.tools) > 0:
|
||||
res = _process_vllm_chat_completion_stream_response(stream)
|
||||
else:
|
||||
res = process_chat_completion_stream_response(stream, self.formatter, request)
|
||||
res = process_chat_completion_stream_response(stream, request)
|
||||
async for chunk in res:
|
||||
yield chunk
|
||||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = self.client.completions.create(**params)
|
||||
return process_completion_response(r, self.formatter)
|
||||
return process_completion_response(r)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
@ -305,7 +319,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_completion_stream_response(stream, self.formatter):
|
||||
async for chunk in process_completion_stream_response(stream):
|
||||
yield chunk
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
|
@ -332,10 +346,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
|
||||
else:
|
||||
assert not request_has_media(request), "vLLM does not support media for Completion requests"
|
||||
input_dict["prompt"] = await completion_request_to_prompt(
|
||||
request,
|
||||
self.formatter,
|
||||
)
|
||||
input_dict["prompt"] = await completion_request_to_prompt(request)
|
||||
|
||||
if fmt := request.response_format:
|
||||
if fmt.type == ResponseFormatType.json_schema.value:
|
||||
|
@ -364,8 +375,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
kwargs = {}
|
||||
assert model.model_type == ModelType.embedding
|
||||
assert model.metadata.get("embedding_dimensions")
|
||||
kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
|
||||
assert model.metadata.get("embedding_dimension")
|
||||
kwargs["dimensions"] = model.metadata.get("embedding_dimension")
|
||||
assert all(not content_has_media(content) for content in contents), "VLLM does not support media for embeddings"
|
||||
response = self.client.embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from .config import ModelContextProtocolConfig
|
||||
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
|
||||
|
||||
|
||||
class ModelContextProtocolToolProviderDataValidator(BaseModel):
|
||||
|
@ -15,6 +14,8 @@ class ModelContextProtocolToolProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
async def get_adapter_impl(config: ModelContextProtocolConfig, _deps):
|
||||
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
|
||||
|
||||
impl = ModelContextProtocolToolRuntimeImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -12,6 +12,20 @@ We use `pytest` and all of its dynamism to enable the features needed. Specifica
|
|||
|
||||
- We use `pytest_collection_modifyitems` to filter tests based on the test config (if specified).
|
||||
|
||||
## Pre-requisites
|
||||
|
||||
Your development environment should have been configured as per the instructions in the
|
||||
[CONTRIBUTING.md](../../../CONTRIBUTING.md) file. In particular, make sure to install the test extra
|
||||
dependencies. Below is the full configuration:
|
||||
|
||||
|
||||
```bash
|
||||
$ cd llama-stack
|
||||
$ uv sync --extra dev --extra test
|
||||
$ uv pip install -e .
|
||||
$ source .venv/bin/activate
|
||||
```
|
||||
|
||||
## Common options
|
||||
|
||||
All tests support a `--providers` option which takes a string of the form `api1=provider_fixture1,api2=provider_fixture2`. So, when testing safety (which needs both the inference and safety APIs), you can use `--providers inference=together,safety=meta_reference` to use these fixtures in concert.
|
||||
|
@ -50,6 +64,9 @@ pytest -s -v llama_stack/providers/tests/inference/test_text_inference.py \
|
|||
--env FIREWORKS_API_KEY=<...>
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> If you’re using `uv`, you can isolate test executions by prefixing all commands with `uv run pytest...`.
|
||||
|
||||
## Agents
|
||||
|
||||
The Agents API composes three other APIs underneath:
|
||||
|
@ -87,3 +104,6 @@ pytest llama_stack/providers/tests/ --config=ci_test_config.yaml
|
|||
Currently, we support test config on inference, agents and memory api tests.
|
||||
|
||||
Example format of test config can be found in ci_test_config.yaml.
|
||||
|
||||
## Test Data
|
||||
We encourage providers to use our test data for internal development testing, to keep it easy and consistent with the tests we provide. Each test case may define its own data format; please refer to our test source code for details on how these fields are used in the tests.
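As a minimal sketch, a provider's own harness can load the shipped JSON blobs directly. The relative path below is an assumption (adjust it to wherever the `test_cases` package lives in your checkout); the key layout matches `completion.json` added in this change:

```python
import json
from pathlib import Path

# Hypothetical location; adjust to your checkout of llama-stack.
TEST_CASES_DIR = Path("llama_stack/providers/tests/test_cases")

# Each file maps a short id (e.g. "01") to {"name": ..., "data": {...}}.
completion_cases = json.loads((TEST_CASES_DIR / "completion.json").read_text())

for case_id, case in completion_cases.items():
    data = case["data"]
    print(case_id, case["name"], data["user_input"], data["expected"])
```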
|
||||
|
|
|
@ -20,7 +20,7 @@ from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
|
|||
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
|
||||
from llama_stack.providers.remote.inference.groq import GroqConfig
|
||||
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
|
||||
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
|
||||
from llama_stack.providers.remote.inference.ollama import DEFAULT_OLLAMA_URL, OllamaImplConfig
|
||||
from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
|
||||
from llama_stack.providers.remote.inference.tgi import TGIImplConfig
|
||||
from llama_stack.providers.remote.inference.together import TogetherImplConfig
|
||||
|
@ -83,17 +83,13 @@ def inference_cerebras() -> ProviderFixture:
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_ollama(inference_model) -> ProviderFixture:
|
||||
inference_model = [inference_model] if isinstance(inference_model, str) else inference_model
|
||||
if inference_model and "Llama3.1-8B-Instruct" in inference_model:
|
||||
pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
|
||||
|
||||
def inference_ollama() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="ollama",
|
||||
provider_type="remote::ollama",
|
||||
config=OllamaImplConfig(host="localhost", port=os.getenv("OLLAMA_PORT", 11434)).model_dump(),
|
||||
config=OllamaImplConfig(url=os.getenv("OLLAMA_URL", DEFAULT_OLLAMA_URL)).model_dump(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
|
|
@ -4,8 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# How to run this test:
|
||||
|
@ -15,6 +13,9 @@ import pytest
|
|||
|
||||
|
||||
class TestModelRegistration:
|
||||
def provider_supports_custom_names(self, provider) -> bool:
|
||||
return "remote::ollama" not in provider.__provider_spec__.provider_type
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_unsupported_model(self, inference_stack, inference_model):
|
||||
inference_impl, models_impl = inference_stack
|
||||
|
@ -47,7 +48,12 @@ class TestModelRegistration:
|
|||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_with_llama_model(self, inference_stack):
|
||||
async def test_register_with_llama_model(self, inference_stack, inference_model):
|
||||
inference_impl, models_impl = inference_stack
|
||||
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
||||
if not self.provider_supports_custom_names(provider):
|
||||
pytest.skip("Provider does not support custom model names")
|
||||
|
||||
_, models_impl = inference_stack
|
||||
|
||||
_ = await models_impl.register_model(
|
||||
|
@ -67,22 +73,6 @@ class TestModelRegistration:
|
|||
provider_model_id="custom-model",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_model_during_registering(self, inference_stack):
|
||||
_, models_impl = inference_stack
|
||||
|
||||
with patch(
|
||||
"llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.load_model",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_load_model:
|
||||
_ = await models_impl.register_model(
|
||||
model_id="Llama3.1-8B-Instruct",
|
||||
metadata={
|
||||
"llama_model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
},
|
||||
)
|
||||
mock_load_model.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_with_invalid_llama_model(self, inference_stack):
|
||||
_, models_impl = inference_stack
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import BaseModel, TypeAdapter, ValidationError
|
||||
|
||||
from llama_stack.apis.common.content_types import ToolCallParseStatus
|
||||
from llama_stack.apis.inference import (
|
||||
|
@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
|
|||
CompletionResponseStreamChunk,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
SystemMessage,
|
||||
ToolChoice,
|
||||
UserMessage,
|
||||
|
@ -30,6 +31,7 @@ from llama_stack.models.llama.datatypes import (
|
|||
ToolParamDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.providers.tests.test_cases.test_case import TestCase
|
||||
|
||||
from .utils import group_chunks
|
||||
|
||||
|
@ -178,8 +180,9 @@ class TestInference:
|
|||
else: # no token, no logprobs
|
||||
assert not chunk.logprobs, "Logprobs should be empty"
|
||||
|
||||
@pytest.mark.parametrize("test_case", ["completion-01"])
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_completion_structured_output(self, inference_model, inference_stack):
|
||||
async def test_completion_structured_output(self, inference_model, inference_stack, test_case):
|
||||
inference_impl, _ = inference_stack
|
||||
|
||||
class Output(BaseModel):
|
||||
|
@ -187,7 +190,9 @@ class TestInference:
|
|||
year_born: str
|
||||
year_retired: str
|
||||
|
||||
user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
|
||||
tc = TestCase(test_case)
|
||||
|
||||
user_input = tc["user_input"]
|
||||
response = await inference_impl.completion(
|
||||
model_id=inference_model,
|
||||
content=user_input,
|
||||
|
@ -203,9 +208,10 @@ class TestInference:
|
|||
assert isinstance(response.content, str)
|
||||
|
||||
answer = Output.model_validate_json(response.content)
|
||||
assert answer.name == "Michael Jordan"
|
||||
assert answer.year_born == "1963"
|
||||
assert answer.year_retired == "2003"
|
||||
expected = tc["expected"]
|
||||
assert answer.name == expected["name"]
|
||||
assert answer.year_born == expected["year_born"]
|
||||
assert answer.year_retired == expected["year_retired"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chat_completion_non_streaming(
|
||||
|
@ -224,8 +230,9 @@ class TestInference:
|
|||
assert isinstance(response.completion_message.content, str)
|
||||
assert len(response.completion_message.content) > 0
|
||||
|
||||
@pytest.mark.parametrize("test_case", ["chat_completion-01"])
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_structured_output(self, inference_model, inference_stack, common_params):
|
||||
async def test_structured_output(self, inference_model, inference_stack, common_params, test_case):
|
||||
inference_impl, _ = inference_stack
|
||||
|
||||
class AnswerFormat(BaseModel):
|
||||
|
@ -234,20 +241,12 @@ class TestInference:
|
|||
year_of_birth: int
|
||||
num_seasons_in_nba: int
|
||||
|
||||
tc = TestCase(test_case)
|
||||
messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]
|
||||
|
||||
response = await inference_impl.chat_completion(
|
||||
model_id=inference_model,
|
||||
messages=[
|
||||
# we include context about Michael Jordan in the prompt so that the test is
|
||||
# focused on the funtionality of the model and not on the information embedded
|
||||
# in the model. Llama 3.2 3B Instruct tends to think MJ played for 14 seasons.
|
||||
SystemMessage(
|
||||
content=(
|
||||
"You are a helpful assistant.\n\n"
|
||||
"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons."
|
||||
)
|
||||
),
|
||||
UserMessage(content="Please give me information about Michael Jordan."),
|
||||
],
|
||||
messages=messages,
|
||||
stream=False,
|
||||
response_format=JsonSchemaResponseFormat(
|
||||
json_schema=AnswerFormat.model_json_schema(),
|
||||
|
@ -260,10 +259,11 @@ class TestInference:
|
|||
assert isinstance(response.completion_message.content, str)
|
||||
|
||||
answer = AnswerFormat.model_validate_json(response.completion_message.content)
|
||||
assert answer.first_name == "Michael"
|
||||
assert answer.last_name == "Jordan"
|
||||
assert answer.year_of_birth == 1963
|
||||
assert answer.num_seasons_in_nba == 15
|
||||
expected = tc["expected"]
|
||||
assert answer.first_name == expected["first_name"]
|
||||
assert answer.last_name == expected["last_name"]
|
||||
assert answer.year_of_birth == expected["year_of_birth"]
|
||||
assert answer.num_seasons_in_nba == expected["num_seasons_in_nba"]
|
||||
|
||||
response = await inference_impl.chat_completion(
|
||||
model_id=inference_model,
|
||||
|
|
|
@ -83,7 +83,6 @@ def pytest_generate_tests(metafunc):
|
|||
if "safety_shield" in metafunc.fixturenames:
|
||||
shield_id = metafunc.config.getoption("--safety-shield")
|
||||
if shield_id:
|
||||
assert shield_id.startswith("meta-llama/")
|
||||
params = [pytest.param(shield_id, id="")]
|
||||
else:
|
||||
params = SAFETY_SHIELD_PARAMS
|
||||
|
|
5
llama_stack/providers/tests/test_cases/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
24
llama_stack/providers/tests/test_cases/chat_completion.json
Normal file
|
@ -0,0 +1,24 @@
|
|||
{
|
||||
"01": {
|
||||
"name": "structured output",
|
||||
"data": {
|
||||
"notes": "We include context about Michael Jordan in the prompt so that the test is focused on the funtionality of the model and not on the information embedded in the model. Llama 3.2 3B Instruct tends to think MJ played for 14 seasons.",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant. Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Please give me information about Michael Jordan."
|
||||
}
|
||||
],
|
||||
"expected": {
|
||||
"first_name": "Michael",
|
||||
"last_name": "Jordan",
|
||||
"year_of_birth": 1963,
|
||||
"num_seasons_in_nba": 15
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
13
llama_stack/providers/tests/test_cases/completion.json
Normal file
|
@ -0,0 +1,13 @@
|
|||
{
|
||||
"01": {
|
||||
"name": "structured output",
|
||||
"data": {
|
||||
"user_input": "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003.",
|
||||
"expected": {
|
||||
"name": "Michael Jordan",
|
||||
"year_born": "1963",
|
||||
"year_retired": "2003"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
32
llama_stack/providers/tests/test_cases/test_case.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import pathlib
|
||||
|
||||
|
||||
class TestCase:
|
||||
_apis = ["chat_completion", "completion"]
|
||||
_jsonblob = {}
|
||||
|
||||
def __init__(self, name):
|
||||
# loading all test cases
|
||||
if self._jsonblob == {}:
|
||||
for api in self._apis:
|
||||
with open(pathlib.Path(__file__).parent / f"{api}.json", "r") as f:
|
||||
TestCase._jsonblob.update({f"{api}-{k}": v for k, v in json.load(f).items()})
|
||||
|
||||
# loading this test case
|
||||
tc = self._jsonblob.get(name)
|
||||
if tc is None:
|
||||
raise ValueError(f"Test case {name} not found")
|
||||
|
||||
# these are the only fields we need
|
||||
self.name = tc.get("name")
|
||||
self.data = tc.get("data")
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.data[key]
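A usage sketch, mirroring how the updated `test_text_inference.py` consumes these entries (keys are prefixed with the API name, e.g. `completion-01`):

```python
from llama_stack.providers.tests.test_cases.test_case import TestCase

tc = TestCase("completion-01")   # resolves to the "01" entry in completion.json
user_input = tc["user_input"]    # prompt text fed to the model
expected = tc["expected"]        # expected structured-output fields
print(tc.name, user_input, expected["name"])
```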
|
160
llama_stack/providers/tests/vector_io/test_sqlite_vec.py
Normal file
|
@ -0,0 +1,160 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import sqlite3
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import sqlite_vec
|
||||
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
|
||||
SQLiteVecIndex,
|
||||
SQLiteVecVectorIOAdapter,
|
||||
generate_chunk_id,
|
||||
)
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py \
|
||||
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||
|
||||
SQLITE_VEC_PROVIDER = "sqlite_vec"
|
||||
EMBEDDING_DIMENSION = 384
|
||||
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def loop():
|
||||
return asyncio.new_event_loop()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def sqlite_connection(loop):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
try:
|
||||
conn.enable_load_extension(True)
|
||||
sqlite_vec.load(conn)
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
async def sqlite_vec_index(sqlite_connection):
|
||||
return await SQLiteVecIndex.create(dimension=EMBEDDING_DIMENSION, connection=sqlite_connection, bank_id="test_bank")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def sample_chunks():
|
||||
"""Generates chunks that force multiple batches for a single document to expose ID conflicts."""
|
||||
n, k = 10, 3
|
||||
sample = [
|
||||
Chunk(content=f"Sentence {i} from document {j}", metadata={"document_id": f"document-{j}"})
|
||||
for j in range(k)
|
||||
for i in range(n)
|
||||
]
|
||||
return sample
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def sample_embeddings(sample_chunks):
|
||||
np.random.seed(42)
|
||||
return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
|
||||
cur = sqlite_vec_index.connection.cursor()
|
||||
cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
|
||||
count = cur.fetchone()[0]
|
||||
assert count == len(sample_chunks)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
query_embedding = np.random.rand(EMBEDDING_DIMENSION).astype(np.float32)
|
||||
response = await sqlite_vec_index.query(query_embedding, k=2, score_threshold=0.0)
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks):
|
||||
"""Test that chunk IDs do not conflict across batches when inserting chunks."""
|
||||
# Reduce batch size to force multiple batches for same document
|
||||
# since there are 10 chunks per document and batch size is 2
|
||||
batch_size = 2
|
||||
sample_embeddings = np.random.rand(len(sample_chunks), EMBEDDING_DIMENSION).astype(np.float32)
|
||||
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=batch_size)
|
||||
|
||||
cur = sqlite_vec_index.connection.cursor()
|
||||
|
||||
# Retrieve all chunk IDs to check for duplicates
|
||||
cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
|
||||
chunk_ids = [row[0] for row in cur.fetchall()]
|
||||
cur.close()
|
||||
|
||||
# Ensure all chunk IDs are unique
|
||||
assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def sqlite_vec_adapter(sqlite_connection):
|
||||
config = type("Config", (object,), {"db_path": ":memory:"}) # Mock config with in-memory database
|
||||
adapter = SQLiteVecVectorIOAdapter(config=config, inference_api=None)
|
||||
await adapter.initialize()
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_vector_db(sqlite_vec_adapter):
|
||||
vector_db = VectorDB(
|
||||
identifier="test_db",
|
||||
embedding_model=EMBEDDING_MODEL,
|
||||
embedding_dimension=EMBEDDING_DIMENSION,
|
||||
metadata={},
|
||||
provider_id=SQLITE_VEC_PROVIDER,
|
||||
)
|
||||
await sqlite_vec_adapter.register_vector_db(vector_db)
|
||||
vector_dbs = await sqlite_vec_adapter.list_vector_dbs()
|
||||
assert any(db.identifier == "test_db" for db in vector_dbs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unregister_vector_db(sqlite_vec_adapter):
|
||||
vector_db = VectorDB(
|
||||
identifier="test_db",
|
||||
embedding_model=EMBEDDING_MODEL,
|
||||
embedding_dimension=EMBEDDING_DIMENSION,
|
||||
metadata={},
|
||||
provider_id=SQLITE_VEC_PROVIDER,
|
||||
)
|
||||
await sqlite_vec_adapter.register_vector_db(vector_db)
|
||||
await sqlite_vec_adapter.unregister_vector_db("test_db")
|
||||
vector_dbs = await sqlite_vec_adapter.list_vector_dbs()
|
||||
assert not any(db.identifier == "test_db" for db in vector_dbs)
|
||||
|
||||
|
||||
def test_generate_chunk_id():
|
||||
chunks = [
|
||||
Chunk(content="test", metadata={"document_id": "doc-1"}),
|
||||
Chunk(content="test ", metadata={"document_id": "doc-1"}),
|
||||
Chunk(content="test 3", metadata={"document_id": "doc-1"}),
|
||||
]
|
||||
|
||||
chunk_ids = sorted([generate_chunk_id(chunk.metadata["document_id"], chunk.content) for chunk in chunks])
|
||||
assert chunk_ids == [
|
||||
"177a1368-f6a8-0c50-6e92-18677f2c3de3",
|
||||
"bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
|
||||
"f68df25d-d9aa-ab4d-5684-64a233add20d",
|
||||
]
|
|
@ -4,8 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.models.models import ModelType
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
|
@ -14,7 +15,15 @@ from llama_stack.providers.utils.inference import (
|
|||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
|
||||
)
|
||||
|
||||
ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"])
|
||||
|
||||
# TODO: this class is more confusing than useful right now. We need to make it
|
||||
# closer to the Model class.
|
||||
class ProviderModelEntry(BaseModel):
|
||||
provider_model_id: str
|
||||
aliases: List[str] = Field(default_factory=list)
|
||||
llama_model: Optional[str] = None
|
||||
model_type: ModelType = ModelType.llm
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
|
||||
|
@ -24,8 +33,8 @@ def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
|
|||
return None
|
||||
|
||||
|
||||
def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAlias:
|
||||
return ModelAlias(
|
||||
def build_hf_repo_model_entry(provider_model_id: str, model_descriptor: str) -> ProviderModelEntry:
|
||||
return ProviderModelEntry(
|
||||
provider_model_id=provider_model_id,
|
||||
aliases=[
|
||||
get_huggingface_repo(model_descriptor),
|
||||
|
@ -34,26 +43,29 @@ def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAli
|
|||
)
|
||||
|
||||
|
||||
def build_model_alias_with_just_provider_model_id(provider_model_id: str, model_descriptor: str) -> ModelAlias:
|
||||
return ModelAlias(
|
||||
def build_model_entry(provider_model_id: str, model_descriptor: str) -> ProviderModelEntry:
|
||||
return ProviderModelEntry(
|
||||
provider_model_id=provider_model_id,
|
||||
aliases=[],
|
||||
llama_model=model_descriptor,
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
||||
|
||||
class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||
def __init__(self, model_aliases: List[ModelAlias]):
|
||||
def __init__(self, model_entries: List[ProviderModelEntry]):
|
||||
self.alias_to_provider_id_map = {}
|
||||
self.provider_id_to_llama_model_map = {}
|
||||
for alias_obj in model_aliases:
|
||||
for alias in alias_obj.aliases:
|
||||
self.alias_to_provider_id_map[alias] = alias_obj.provider_model_id
|
||||
for entry in model_entries:
|
||||
for alias in entry.aliases:
|
||||
self.alias_to_provider_id_map[alias] = entry.provider_model_id
|
||||
|
||||
# also add a mapping from provider model id to itself for easy lookup
|
||||
self.alias_to_provider_id_map[alias_obj.provider_model_id] = alias_obj.provider_model_id
|
||||
# ensure we can go from llama model to provider model id
|
||||
self.alias_to_provider_id_map[alias_obj.llama_model] = alias_obj.provider_model_id
|
||||
self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = alias_obj.llama_model
|
||||
self.alias_to_provider_id_map[entry.provider_model_id] = entry.provider_model_id
|
||||
|
||||
if entry.llama_model:
|
||||
self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
|
||||
self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model
|
||||
|
||||
def get_provider_model_id(self, identifier: str) -> Optional[str]:
|
||||
return self.alias_to_provider_id_map.get(identifier, None)
|
||||
|
|
|
@ -7,7 +7,6 @@ import json
|
|||
import logging
|
||||
from typing import AsyncGenerator, Dict, List, Optional, Union
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from openai.types.chat import ChatCompletionMessageToolCall
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -40,6 +39,7 @@ from llama_stack.models.llama.datatypes import (
|
|||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
convert_image_content_to_url,
|
||||
decode_assistant_message,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -149,7 +149,7 @@ def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Unio
|
|||
return None
|
||||
|
||||
|
||||
def process_completion_response(response: OpenAICompatCompletionResponse, formatter: ChatFormat) -> CompletionResponse:
|
||||
def process_completion_response(response: OpenAICompatCompletionResponse) -> CompletionResponse:
|
||||
choice = response.choices[0]
|
||||
# drop suffix <eot_id> if present and return stop reason as end of turn
|
||||
if choice.text.endswith("<|eot_id|>"):
|
||||
|
@ -174,16 +174,13 @@ def process_completion_response(response: OpenAICompatCompletionResponse, format
|
|||
|
||||
def process_chat_completion_response(
|
||||
response: OpenAICompatCompletionResponse,
|
||||
formatter: ChatFormat,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ChatCompletionResponse:
|
||||
choice = response.choices[0]
|
||||
|
||||
# TODO: This does not work well with tool calls for vLLM remote provider
|
||||
# Ref: https://github.com/meta-llama/llama-stack/issues/1058
|
||||
raw_message = formatter.decode_assistant_message_from_content(
|
||||
text_from_choice(choice), get_stop_reason(choice.finish_reason)
|
||||
)
|
||||
raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))
|
||||
|
||||
# NOTE: If we do not set tools in chat-completion request, we should not
|
||||
# expect the ToolCall in the response. Instead, we should return the raw
|
||||
|
@ -217,7 +214,7 @@ def process_chat_completion_response(
|
|||
|
||||
|
||||
async def process_completion_stream_response(
|
||||
stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
|
||||
stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
|
||||
) -> AsyncGenerator:
|
||||
stop_reason = None
|
||||
|
||||
|
@ -254,7 +251,6 @@ async def process_completion_stream_response(
|
|||
|
||||
async def process_chat_completion_stream_response(
|
||||
stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
|
||||
formatter: ChatFormat,
|
||||
request: ChatCompletionRequest,
|
||||
) -> AsyncGenerator:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
|
@ -333,7 +329,7 @@ async def process_chat_completion_stream_response(
|
|||
)
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
|
||||
message = decode_assistant_message(buffer, stop_reason)
|
||||
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
|
|
|
@ -13,7 +13,9 @@ import re
|
|||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
from llama_models.datatypes import StopReason
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from PIL import Image as PIL_Image
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
|
@ -66,6 +68,11 @@ class CompletionRequestWithRawContent(CompletionRequest):
|
|||
content: RawContent
|
||||
|
||||
|
||||
def decode_assistant_message(content: str, stop_reason: StopReason) -> RawMessage:
|
||||
formatter = ChatFormat(Tokenizer.get_instance())
|
||||
return formatter.decode_assistant_message_from_content(content, stop_reason)
|
||||
|
||||
|
||||
def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
|
||||
def _process(c) -> str:
|
||||
if isinstance(c, str):
|
||||
|
@ -207,20 +214,22 @@ async def convert_image_content_to_url(
|
|||
return base64.b64encode(content).decode("utf-8")
|
||||
|
||||
|
||||
async def completion_request_to_prompt(request: CompletionRequest, formatter: ChatFormat) -> str:
|
||||
async def completion_request_to_prompt(request: CompletionRequest) -> str:
|
||||
content = augment_content_with_response_format_prompt(request.response_format, request.content)
|
||||
request.content = content
|
||||
request = await convert_request_to_raw(request)
|
||||
|
||||
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
|
||||
model_input = formatter.encode_content(request.content)
|
||||
return formatter.tokenizer.decode(model_input.tokens)
|
||||
|
||||
|
||||
async def completion_request_to_prompt_model_input_info(
|
||||
request: CompletionRequest, formatter: ChatFormat
|
||||
) -> Tuple[str, int]:
|
||||
async def completion_request_to_prompt_model_input_info(request: CompletionRequest) -> Tuple[str, int]:
|
||||
content = augment_content_with_response_format_prompt(request.response_format, request.content)
|
||||
request.content = content
|
||||
request = await convert_request_to_raw(request)
|
||||
|
||||
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
|
||||
model_input = formatter.encode_content(request.content)
|
||||
return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens))
|
||||
|
||||
|
@ -237,22 +246,24 @@ def augment_content_with_response_format_prompt(response_format, content):
|
|||
return content
|
||||
|
||||
|
||||
async def chat_completion_request_to_prompt(
|
||||
request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
|
||||
) -> str:
|
||||
async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
|
||||
messages = chat_completion_request_to_messages(request, llama_model)
|
||||
request.messages = messages
|
||||
request = await convert_request_to_raw(request)
|
||||
|
||||
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
|
||||
model_input = formatter.encode_dialog_prompt(request.messages)
|
||||
return formatter.tokenizer.decode(model_input.tokens)
|
||||
|
||||
|
||||
async def chat_completion_request_to_model_input_info(
|
||||
request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
|
||||
request: ChatCompletionRequest, llama_model: str
|
||||
) -> Tuple[str, int]:
|
||||
messages = chat_completion_request_to_messages(request, llama_model)
|
||||
request.messages = messages
|
||||
request = await convert_request_to_raw(request)
|
||||
|
||||
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
|
||||
model_input = formatter.encode_dialog_prompt(request.messages)
|
||||
return (
|
||||
formatter.tokenizer.decode(model_input.tokens),
|
||||
|
|
|
@ -18,6 +18,7 @@ class KVStoreType(Enum):
|
|||
redis = "redis"
|
||||
sqlite = "sqlite"
|
||||
postgres = "postgres"
|
||||
mongodb = "mongodb"
|
||||
|
||||
|
||||
class CommonConfig(BaseModel):
|
||||
|
@ -101,7 +102,30 @@ class PostgresKVStoreConfig(CommonConfig):
|
|||
return v
|
||||
|
||||
|
||||
class MongoDBKVStoreConfig(CommonConfig):
|
||||
type: Literal[KVStoreType.mongodb.value] = KVStoreType.mongodb.value
|
||||
host: str = "localhost"
|
||||
port: int = 27017
|
||||
db: str = "llamastack"
|
||||
user: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
collection_name: str = "llamastack_kvstore"
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, collection_name: str = "llamastack_kvstore"):
|
||||
return {
|
||||
"type": "mongodb",
|
||||
"namespace": None,
|
||||
"host": "${env.MONGODB_HOST:localhost}",
|
||||
"port": "${env.MONGODB_PORT:5432}",
|
||||
"db": "${env.MONGODB_DB}",
|
||||
"user": "${env.MONGODB_USER}",
|
||||
"password": "${env.MONGODB_PASSWORD}",
|
||||
"collection_name": "${env.MONGODB_COLLECTION_NAME:" + collection_name + "}",
|
||||
}
|
||||
|
||||
|
||||
KVStoreConfig = Annotated[
|
||||
Union[RedisKVStoreConfig, SqliteKVStoreConfig, PostgresKVStoreConfig],
|
||||
Union[RedisKVStoreConfig, SqliteKVStoreConfig, PostgresKVStoreConfig, MongoDBKVStoreConfig],
|
||||
Field(discriminator="type", default=KVStoreType.sqlite.value),
|
||||
]
|
||||
|
|
|
@ -11,7 +11,7 @@ from .config import KVStoreConfig, KVStoreType
|
|||
|
||||
|
||||
def kvstore_dependencies():
|
||||
return ["aiosqlite", "psycopg2-binary", "redis"]
|
||||
return ["aiosqlite", "psycopg2-binary", "redis", "pymongo"]
|
||||
|
||||
|
||||
class InmemoryKVStoreImpl(KVStore):
|
||||
|
@ -44,6 +44,10 @@ async def kvstore_impl(config: KVStoreConfig) -> KVStore:
|
|||
from .postgres import PostgresKVStoreImpl
|
||||
|
||||
impl = PostgresKVStoreImpl(config)
|
||||
elif config.type == KVStoreType.mongodb.value:
|
||||
from .mongodb import MongoDBKVStoreImpl
|
||||
|
||||
impl = MongoDBKVStoreImpl(config)
|
||||
else:
|
||||
raise ValueError(f"Unknown kvstore type {config.type}")
|
||||
|
||||
|
|
9
llama_stack/providers/utils/kvstore/mongodb/__init__.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .mongodb import MongoDBKVStoreImpl
|
||||
|
||||
__all__ = ["MongoDBKVStoreImpl"]
|
66
llama_stack/providers/utils/kvstore/mongodb/mongodb.py
Normal file
|
@ -0,0 +1,66 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from pymongo import MongoClient
|
||||
|
||||
from llama_stack.providers.utils.kvstore import KVStore, MongoDBKVStoreConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MongoDBKVStoreImpl(KVStore):
|
||||
def __init__(self, config: MongoDBKVStoreConfig):
|
||||
self.config = config
|
||||
self.conn = None
|
||||
self.collection = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
conn_creds = {
|
||||
"host": self.config.host,
|
||||
"port": self.config.port,
|
||||
"username": self.config.user,
|
||||
"password": self.config.password,
|
||||
}
|
||||
conn_creds = {k: v for k, v in conn_creds.items() if v is not None}
|
||||
self.conn = MongoClient(**conn_creds)
|
||||
self.collection = self.conn[self.config.db][self.config.collection_name]
|
||||
except Exception as e:
|
||||
log.exception("Could not connect to MongoDB database server")
|
||||
raise RuntimeError("Could not connect to MongoDB database server") from e
|
||||
|
||||
def _namespaced_key(self, key: str) -> str:
|
||||
if not self.config.namespace:
|
||||
return key
|
||||
return f"{self.config.namespace}:{key}"
|
||||
|
||||
async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
update_query = {"$set": {"value": value, "expiration": expiration}}
|
||||
self.collection.update_one({"key": key}, update_query, upsert=True)
|
||||
|
||||
async def get(self, key: str) -> Optional[str]:
|
||||
key = self._namespaced_key(key)
|
||||
query = {"key": key}
|
||||
result = self.collection.find_one(query, {"value": 1, "_id": 0})
|
||||
return result["value"] if result else None
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
self.collection.delete_one({"key": key})
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[str]:
|
||||
start_key = self._namespaced_key(start_key)
|
||||
end_key = self._namespaced_key(end_key)
|
||||
query = {
|
||||
"key": {"$gte": start_key, "$lt": end_key},
|
||||
}
|
||||
cursor = self.collection.find(query, {"value": 1, "_id": 0}).sort("key", 1)
|
||||
return [doc["value"] for doc in cursor]
|
|
@ -19,6 +19,7 @@ class WebMethod:
|
|||
request_examples: Optional[List[Any]] = None
|
||||
response_examples: Optional[List[Any]] = None
|
||||
method: Optional[str] = None
|
||||
raw_bytes_request_body: Optional[bool] = False
|
||||
|
||||
|
||||
def webmethod(
|
||||
|
@ -27,6 +28,7 @@ def webmethod(
|
|||
public: Optional[bool] = False,
|
||||
request_examples: Optional[List[Any]] = None,
|
||||
response_examples: Optional[List[Any]] = None,
|
||||
raw_bytes_request_body: Optional[bool] = False,
|
||||
) -> Callable[[T], T]:
|
||||
"""
|
||||
Decorator that supplies additional metadata to an endpoint operation function.
|
||||
|
@ -44,6 +46,7 @@ def webmethod(
|
|||
public=public or False,
|
||||
request_examples=request_examples,
|
||||
response_examples=response_examples,
|
||||
raw_bytes_request_body=raw_bytes_request_body,
|
||||
)
|
||||
return cls
|
||||
|
||||
|
|
|
@ -23,6 +23,22 @@ from llama_stack.distribution.build import (
|
|||
REPO_ROOT = Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
class ChangedPathTracker:
|
||||
"""Track a list of paths we may have changed."""
|
||||
|
||||
def __init__(self):
|
||||
self._changed_paths = []
|
||||
|
||||
def add_paths(self, *paths):
|
||||
for path in paths:
|
||||
path = str(path)
|
||||
if path not in self._changed_paths:
|
||||
self._changed_paths.append(path)
|
||||
|
||||
def changed_paths(self):
|
||||
return self._changed_paths
|
||||
|
||||
|
||||
def find_template_dirs(templates_dir: Path) -> Iterator[Path]:
|
||||
"""Find immediate subdirectories in the templates folder."""
|
||||
if not templates_dir.exists():
|
||||
|
@ -31,7 +47,7 @@ def find_template_dirs(templates_dir: Path) -> Iterator[Path]:
|
|||
return sorted(d for d in templates_dir.iterdir() if d.is_dir() and d.name != "__pycache__")
|
||||
|
||||
|
||||
def process_template(template_dir: Path, progress) -> None:
|
||||
def process_template(template_dir: Path, progress, change_tracker: ChangedPathTracker) -> None:
|
||||
"""Process a single template directory."""
|
||||
progress.print(f"Processing {template_dir.name}")
|
||||
|
||||
|
@ -44,9 +60,12 @@ def process_template(template_dir: Path, progress) -> None:
|
|||
if template_func := getattr(module, "get_distribution_template", None):
|
||||
template = template_func()
|
||||
|
||||
yaml_output_dir = REPO_ROOT / "llama_stack" / "templates" / template.name
|
||||
doc_output_dir = REPO_ROOT / "docs/source/distributions" / f"{template.distro_type}_distro"
|
||||
change_tracker.add_paths(yaml_output_dir, doc_output_dir)
|
||||
template.save_distribution(
|
||||
yaml_output_dir=REPO_ROOT / "llama_stack" / "templates" / template.name,
|
||||
doc_output_dir=REPO_ROOT / "docs/source/distributions" / f"{template.distro_type}_distro",
|
||||
yaml_output_dir=yaml_output_dir,
|
||||
doc_output_dir=doc_output_dir,
|
||||
)
|
||||
else:
|
||||
progress.print(f"[yellow]Warning: {template_dir.name} has no get_distribution_template function")
|
||||
|
@ -56,14 +75,19 @@ def process_template(template_dir: Path, progress) -> None:
|
|||
raise e
|
||||
|
||||
|
||||
def check_for_changes() -> bool:
|
||||
def check_for_changes(change_tracker: ChangedPathTracker) -> bool:
|
||||
"""Check if there are any uncommitted changes."""
|
||||
result = subprocess.run(
|
||||
["git", "diff", "--exit-code"],
|
||||
cwd=REPO_ROOT,
|
||||
capture_output=True,
|
||||
)
|
||||
return result.returncode != 0
|
||||
has_changes = False
|
||||
for path in change_tracker.changed_paths():
|
||||
result = subprocess.run(
|
||||
["git", "diff", "--exit-code", path],
|
||||
cwd=REPO_ROOT,
|
||||
capture_output=True,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
print(f"Change detected in '{path}'.", file=sys.stderr)
|
||||
has_changes = True
|
||||
return has_changes
|
||||
|
||||
|
||||
def collect_template_dependencies(template_dir: Path) -> tuple[str, list[str]]:
|
||||
|
@ -83,7 +107,7 @@ def collect_template_dependencies(template_dir: Path) -> tuple[str, list[str]]:
|
|||
return None, []
|
||||
|
||||
|
||||
def generate_dependencies_file():
|
||||
def generate_dependencies_file(change_tracker: ChangedPathTracker):
|
||||
templates_dir = REPO_ROOT / "llama_stack" / "templates"
|
||||
distribution_deps = {}
|
||||
|
||||
|
@ -93,12 +117,14 @@ def generate_dependencies_file():
|
|||
distribution_deps[name] = deps
|
||||
|
||||
deps_file = REPO_ROOT / "distributions" / "dependencies.json"
|
||||
change_tracker.add_paths(deps_file)
|
||||
with open(deps_file, "w") as f:
|
||||
f.write(json.dumps(distribution_deps, indent=2) + "\n")
|
||||
|
||||
|
||||
def main():
|
||||
templates_dir = REPO_ROOT / "llama_stack" / "templates"
|
||||
change_tracker = ChangedPathTracker()
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
|
@ -108,7 +134,7 @@ def main():
|
|||
task = progress.add_task("Processing distribution templates...", total=len(template_dirs))
|
||||
|
||||
# Create a partial function with the progress bar
|
||||
process_func = partial(process_template, progress=progress)
|
||||
process_func = partial(process_template, progress=progress, change_tracker=change_tracker)
|
||||
|
||||
# Process templates in parallel
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
|
@ -116,9 +142,9 @@ def main():
|
|||
list(executor.map(process_func, template_dirs))
|
||||
progress.update(task, advance=len(template_dirs))
|
||||
|
||||
generate_dependencies_file()
|
||||
generate_dependencies_file(change_tracker)
|
||||
|
||||
if check_for_changes():
|
||||
if check_for_changes(change_tracker):
|
||||
print(
|
||||
"Distribution template changes detected. Please commit the changes.",
|
||||
file=sys.stderr,
|
||||
|
|
66
llama_stack/scripts/run_client_sdk_tests.py
Normal file
|
@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse
import os
from pathlib import Path

import pytest

"""
Script for running client-sdk on AsyncLlamaStackAsLibraryClient with templates

Assuming directory structure:
- llama-stack
    - llama_stack
        - scripts
    - tests
        - client-sdk

Example command:

cd llama-stack
export TOGETHER_API_KEY=<..>
export FIREWORKS_API_KEY=<..>
python llama_stack/scripts/run_client_sdk_tests.py --templates together fireworks --report
"""

REPO_ROOT = Path(__file__).parent.parent.parent
CLIENT_SDK_TESTS_RELATIVE_PATH = "tests/client-sdk/"


def main(parser: argparse.ArgumentParser):
    args = parser.parse_args()
    templates_dir = REPO_ROOT / "llama_stack" / "templates"
    user_specified_templates = [templates_dir / t for t in args.templates] if args.templates else []
    for d in templates_dir.iterdir():
        if d.is_dir() and d.name != "__pycache__":
            template_configs = list(d.rglob("run.yaml"))
            if len(template_configs) == 0:
                continue
            config = template_configs[0]
            if user_specified_templates:
                if not any(config.parent == t for t in user_specified_templates):
                    continue
            os.environ["LLAMA_STACK_CONFIG"] = str(config)
            pytest_args = "--report" if args.report else ""
            pytest.main(
                [
                    pytest_args,
                    "-s",
                    "-v",
                    REPO_ROOT / CLIENT_SDK_TESTS_RELATIVE_PATH,
                ]
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="llama_test",
    )
    parser.add_argument("--templates", nargs="+")
    parser.add_argument("--report", action="store_true")
    main(parser)
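The script can also be limited to a single template and run without the report flag; a hypothetical invocation, assuming an `ollama` template with a `run.yaml` exists under `llama_stack/templates/` and the providers it references are reachable:

```bash
cd llama-stack
python llama_stack/scripts/run_client_sdk_tests.py --templates ollama
```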
@@ -10,7 +10,7 @@ from llama_stack.apis.models import ModelInput
 from llama_stack.distribution.datatypes import Provider, ToolGroupInput
 from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
-from llama_stack.providers.remote.inference.bedrock.bedrock import MODEL_ALIASES
+from llama_stack.providers.remote.inference.bedrock.models import MODEL_ENTRIES
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings

@@ -47,7 +47,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_model_id=m.provider_model_id,
             provider_id="bedrock",
         )
-        for m in MODEL_ALIASES
+        for m in MODEL_ENTRIES
     ]
     default_tool_groups = [
         ToolGroupInput(
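The same `MODEL_ALIASES` → `MODEL_ENTRIES` rename recurs in the cerebras, fireworks, and nvidia hunks below: each remote inference provider's model list now lives in a dedicated `models.py` instead of the provider module itself. Purely as an illustration of the shape such a module boils down to (the entry type and field names here are assumptions, not the repository's actual helper):

```python
# Hypothetical sketch of a provider models.py; the real entries are built with
# the repository's own entry type, which is not shown in this diff.
from dataclasses import dataclass, field


@dataclass
class ProviderModelEntry:
    provider_model_id: str              # the provider's native model identifier
    llama_model: str | None = None      # matching Llama descriptor, if any
    metadata: dict = field(default_factory=dict)


MODEL_ENTRIES = [
    ProviderModelEntry(
        provider_model_id="meta.llama3-1-8b-instruct-v1:0",  # example Bedrock model id
        llama_model="Llama3.1-8B-Instruct",
    ),
]
```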
@@ -14,7 +14,7 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
 )
 from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
 from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
-from llama_stack.providers.remote.inference.cerebras.cerebras import model_aliases
+from llama_stack.providers.remote.inference.cerebras.models import model_entries
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings

@@ -55,7 +55,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_model_id=m.provider_model_id,
             provider_id="cerebras",
         )
-        for m in model_aliases
+        for m in model_entries
     ]
     embedding_model = ModelInput(
         model_id="all-MiniLM-L6-v2",
@@ -19,7 +19,7 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
 )
 from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
 from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
-from llama_stack.providers.remote.inference.fireworks.fireworks import MODEL_ALIASES
+from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings

@@ -63,11 +63,13 @@ def get_distribution_template() -> DistributionTemplate:
     core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
     default_models = [
         ModelInput(
-            model_id=core_model_to_hf_repo[m.llama_model],
+            model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
             provider_model_id=m.provider_model_id,
             provider_id="fireworks",
+            metadata=m.metadata,
+            model_type=m.model_type,
         )
-        for m in MODEL_ALIASES
+        for m in MODEL_ENTRIES
     ]
     embedding_model = ModelInput(
         model_id="all-MiniLM-L6-v2",
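The new `model_id` fallback is what lets non-Llama entries (such as the nomic embedding model registered in the run configs below) flow through the same comprehension: those entries carry no `llama_model`, so there is nothing to look up in `core_model_to_hf_repo`. A self-contained sketch of the two cases, with illustrative identifiers rather than the template's real entry objects:

```python
from collections import namedtuple

# Illustrative stand-in for a provider model entry.
Entry = namedtuple("Entry", ["provider_model_id", "llama_model"])

core_model_to_hf_repo = {"Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct"}

entries = [
    Entry("accounts/fireworks/models/llama-v3p1-8b-instruct", "Llama3.1-8B-Instruct"),
    Entry("nomic-ai/nomic-embed-text-v1.5", None),  # embedding model: no Llama descriptor
]

for m in entries:
    # Same expression as in the hunk above: prefer the HF repo name, fall back
    # to the provider's own id when the entry is not a Llama model.
    model_id = core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id
    print(model_id)
# -> meta-llama/Llama-3.1-8B-Instruct
# -> nomic-ai/nomic-embed-text-v1.5
```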
@@ -149,6 +149,13 @@ models:
   provider_id: fireworks
   provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
   model_type: llm
+- metadata:
+    embedding_dimension: 768
+    context_length: 8192
+  model_id: nomic-ai/nomic-embed-text-v1.5
+  provider_id: fireworks
+  provider_model_id: nomic-ai/nomic-embed-text-v1.5
+  model_type: embedding
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
@@ -143,6 +143,13 @@ models:
   provider_id: fireworks
   provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
   model_type: llm
+- metadata:
+    embedding_dimension: 768
+    context_length: 8192
+  model_id: nomic-ai/nomic-embed-text-v1.5
+  provider_id: fireworks
+  provider_model_id: nomic-ai/nomic-embed-text-v1.5
+  model_type: embedding
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
@@ -10,7 +10,7 @@ from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
 from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
-from llama_stack.providers.remote.inference.nvidia.nvidia import _MODEL_ALIASES
+from llama_stack.providers.remote.inference.nvidia.models import _MODEL_ENTRIES
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings

@@ -59,7 +59,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_model_id=m.provider_model_id,
             provider_id="nvidia",
         )
-        for m in _MODEL_ALIASES
+        for m in _MODEL_ENTRIES
     ]
     default_tool_groups = [
         ToolGroupInput(
@@ -6,7 +6,6 @@ distribution_spec:
    - remote::ollama
    vector_io:
    - inline::faiss
    - inline::sqlite_vec
    - remote::chromadb
    - remote::pgvector
    safety:
@@ -119,7 +119,7 @@ llama stack run ./run-with-safety.yaml \
 ### (Optional) Update Model Serving Configuration

 ```{note}
-Please check the [model_aliases](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L45) for the supported Ollama models.
+Please check the [model_entries](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L45) for the supported Ollama models.
 ```

 To serve a new model with `ollama`
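As a concrete example of the step this note introduces (the model tag is illustrative, and flags may vary by Ollama version; pick a tag from the linked `model_entries` list), the model first has to be running in Ollama before it can be registered with the stack:

```bash
# Start the model in Ollama and keep it loaded long enough for the stack to use it.
ollama run llama3.2:3b-instruct-fp16 --keepalive 60m
```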
@@ -71,7 +71,8 @@ def get_distribution_template() -> DistributionTemplate:
     )
     embedding_model = ModelInput(
         model_id="all-MiniLM-L6-v2",
-        provider_id="sentence-transformers",
+        provider_id="ollama",
+        provider_model_id="all-minilm:latest",
         model_type=ModelType.embedding,
         metadata={
             "embedding_dimension": 384,
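Because the embedding model is now served by the `ollama` provider instead of the inline sentence-transformers provider, the corresponding Ollama model has to exist locally; the tag comes straight from the `provider_model_id` above:

```bash
# Fetch the embedding model so the ollama provider can serve all-MiniLM-L6-v2.
ollama pull all-minilm:latest
```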
@@ -20,6 +20,13 @@ providers:
    provider_type: inline::sentence-transformers
    config: {}
  vector_io:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
  - provider_id: faiss
    provider_type: inline::faiss
    config:

@@ -103,7 +110,8 @@ models:
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
-  provider_id: sentence-transformers
+  provider_id: ollama
+  provider_model_id: all-minilm:latest
   model_type: embedding
 shields:
 - shield_id: ${env.SAFETY_MODEL}