Merge branch 'main' into jwm4-add-qdrant-to-provider-tests
This commit is contained in: commit 89411c839a
198 changed files with 1432 additions and 646 deletions
.github/PULL_REQUEST_TEMPLATE.md (1 change, vendored)
@@ -8,4 +8,3 @@
 [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*]

 [//]: # (## Documentation)
-[//]: # (- [ ] Added a Changelog entry if the change is significant)
@@ -29,10 +29,12 @@ repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     rev: v0.9.4
     hooks:
+      # Run the linter with import sorting.
       - id: ruff
         args: [
           --fix,
-          --exit-non-zero-on-fix
+          --exit-non-zero-on-fix,
+          --select, I,
         ]
       - id: ruff-format
@@ -40,6 +40,7 @@ If you need help or guidance, comment on the issue. Issues that are extra friend
 3. Ensure the test suite passes.
 4. Make sure your code lints using `pre-commit`.
 5. If you haven't already, complete the Contributor License Agreement ("CLA").
+6. Ensure your pull request follows the [conventional commits format](https://www.conventionalcommits.org/en/v1.0.0/).

 ## Contributor License Agreement ("CLA")
 In order to accept your pull request, we need you to submit a CLA. You only need
@@ -98,7 +99,8 @@ $ uv sync
 ```

 ## Coding Style
-* 2 spaces for indentation rather than tabs
+* 4 spaces for indentation rather than tabs
 * 80 character line length
 * ...
README.md (33 changes)
@@ -7,13 +7,13 @@
 [**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb)

-Llama Stack defines and standardizes the core building blocks that simplify AI application development. It codified best practices across the Llama ecosystem. More specifically, it provides
+Llama Stack standardizes the core building blocks that simplify AI application development. It codifies best practices across the Llama ecosystem. More specifically, it provides

 - **Unified API layer** for Inference, RAG, Agents, Tools, Safety, Evals, and Telemetry.
-- **Plugin architecture** to support the rich ecosystem of implementations of the different APIs in different environments like local development, on-premises, cloud, and mobile.
+- **Plugin architecture** to support the rich ecosystem of different API implementations in various environments, including local development, on-premises, cloud, and mobile.
-- **Prepackaged verified distributions** which offer a one-stop solution for developers to get started quickly and reliably in any environment
+- **Prepackaged verified distributions** which offer a one-stop solution for developers to get started quickly and reliably in any environment.
-- **Multiple developer interfaces** like CLI and SDKs for Python, Typescript, iOS, and Android
+- **Multiple developer interfaces** like CLI and SDKs for Python, Typescript, iOS, and Android.
-- **Standalone applications** as examples for how to build production-grade AI applications with Llama Stack
+- **Standalone applications** as examples for how to build production-grade AI applications with Llama Stack.

 <div style="text-align: center;">
 <img
@@ -25,14 +25,14 @@ Llama Stack defines and standardizes the core building blocks that simplify AI a
 </div>

 ### Llama Stack Benefits
-- **Flexible Options**: Developers can choose their preferred infrastructure without changing APIs and enjoy flexible deployment choice.
+- **Flexible Options**: Developers can choose their preferred infrastructure without changing APIs and enjoy flexible deployment choices.
-- **Consistent Experience**: With its unified APIs Llama Stack makes it easier to build, test, and deploy AI applications with consistent application behavior.
+- **Consistent Experience**: With its unified APIs, Llama Stack makes it easier to build, test, and deploy AI applications with consistent application behavior.
 - **Robust Ecosystem**: Llama Stack is already integrated with distribution partners (cloud providers, hardware vendors, and AI-focused companies) that offer tailored infrastructure, software, and services for deploying Llama models.

 By reducing friction and complexity, Llama Stack empowers developers to focus on what they do best: building transformative generative AI applications.

 ### API Providers
-Here is a list of the various API providers and available distributions to developers started easily,
+Here is a list of the various API providers and available distributions that can help developers get started easily with Llama Stack.

 | **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
 |:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|
@@ -71,15 +71,15 @@ A Llama Stack Distribution (or "distro") is a pre-configured bundle of provider
 You have two ways to install this repository:

-1. **Install as a package**:
+* **Install as a package**:
   You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command:
   ```bash
   pip install llama-stack
   ```

-2. **Install from source**:
+* **Install from source**:
   If you prefer to install from the source code, make sure you have [conda installed](https://docs.conda.io/projects/conda/en/stable).
-  Then, follow these steps:
+  Then, run the following commands:
   ```bash
   mkdir -p ~/local
   cd ~/local
@@ -96,10 +96,11 @@ You have two ways to install this repository:
 Please checkout our [Documentation](https://llama-stack.readthedocs.io/en/latest/index.html) page for more details.

-* [CLI reference](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/index.html)
-    * Guide using `llama` CLI to work with Llama models (download, study prompts), and building/starting a Llama Stack distribution.
-* [Getting Started](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html)
-    * Quick guide to start a Llama Stack server.
+* CLI references
+    * [llama (server-side) CLI Reference](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/index.html): Guide for using the `llama` CLI to work with Llama models (download, study prompts), and building/starting a Llama Stack distribution.
+    * [llama (client-side) CLI Reference](https://llama-stack.readthedocs.io/en/latest/references/llama_stack_client_cli_reference.html): Guide for using the `llama-stack-client` CLI, which allows you to query information about the distribution.
+* Getting Started
+    * [Quick guide to start a Llama Stack server](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html).
 * [Jupyter notebook](./docs/getting_started.ipynb) to walk-through how to use simple text and vision inference llama_stack_client APIs
 * The complete Llama Stack lesson [Colab notebook](https://colab.research.google.com/drive/1dtVmxotBsI4cGZQNsJRYPrLiDeT0Wnwt) of the new [Llama 3.2 course on Deeplearning.ai](https://learn.deeplearning.ai/courses/introducing-multimodal-llama-3-2/lesson/8/llama-stack).
 * A [Zero-to-Hero Guide](https://github.com/meta-llama/llama-stack/tree/main/docs/zero_to_hero_guide) that guide you through all the key components of llama stack with code samples.
@@ -115,6 +116,6 @@ Please checkout our [Documentation](https://llama-stack.readthedocs.io/en/latest
 | Typescript | [llama-stack-client-typescript](https://github.com/meta-llama/llama-stack-client-typescript) | [](https://npmjs.org/package/llama-stack-client)
 | Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) | [](https://central.sonatype.com/artifact/com.llama.llamastack/llama-stack-client-kotlin)

-Check out our client SDKs for connecting to Llama Stack server in your preferred language, you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [typescript](https://github.com/meta-llama/llama-stack-client-typescript), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) programming languages to quickly build your applications.
+Check out our client SDKs for connecting to a Llama Stack server in your preferred language, you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [typescript](https://github.com/meta-llama/llama-stack-client-typescript), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) programming languages to quickly build your applications.

 You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repo.
docs/_static/llama-stack-spec.html (176 changes, vendored)
@@ -3106,6 +3106,12 @@
       "ChatCompletionResponse": {
         "type": "object",
         "properties": {
+          "metrics": {
+            "type": "array",
+            "items": {
+              "$ref": "#/components/schemas/MetricEvent"
+            }
+          },
           "completion_message": {
             "$ref": "#/components/schemas/CompletionMessage",
             "description": "The complete response message"
@@ -3124,6 +3130,74 @@
         ],
         "description": "Response from a chat completion request."
       },
+      "MetricEvent": {
+        "type": "object",
+        "properties": {
+          "trace_id": {
+            "type": "string"
+          },
+          "span_id": {
+            "type": "string"
+          },
+          "timestamp": {
+            "type": "string",
+            "format": "date-time"
+          },
+          "attributes": {
+            "type": "object",
+            "additionalProperties": {
+              "oneOf": [
+                {
+                  "type": "string"
+                },
+                {
+                  "type": "integer"
+                },
+                {
+                  "type": "number"
+                },
+                {
+                  "type": "boolean"
+                },
+                {
+                  "type": "null"
+                }
+              ]
+            }
+          },
+          "type": {
+            "type": "string",
+            "const": "metric",
+            "default": "metric"
+          },
+          "metric": {
+            "type": "string"
+          },
+          "value": {
+            "oneOf": [
+              {
+                "type": "integer"
+              },
+              {
+                "type": "number"
+              }
+            ]
+          },
+          "unit": {
+            "type": "string"
+          }
+        },
+        "additionalProperties": false,
+        "required": [
+          "trace_id",
+          "span_id",
+          "timestamp",
+          "type",
+          "metric",
+          "value",
+          "unit"
+        ]
+      },
       "TokenLogProbs": {
         "type": "object",
         "properties": {
@@ -3388,6 +3462,12 @@
       "ChatCompletionResponseStreamChunk": {
         "type": "object",
         "properties": {
+          "metrics": {
+            "type": "array",
+            "items": {
+              "$ref": "#/components/schemas/MetricEvent"
+            }
+          },
           "event": {
             "$ref": "#/components/schemas/ChatCompletionResponseEvent",
             "description": "The event containing the new content"
@@ -3600,8 +3680,7 @@
             "auto",
             "required"
           ],
-          "description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model.",
-          "default": "auto"
+          "description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
         },
         "tool_prompt_format": {
           "type": "string",
@@ -6374,77 +6453,6 @@
           "critical"
         ]
       },
-      "MetricEvent": {
-        "type": "object",
-        "properties": {
-          "trace_id": {
-            "type": "string"
-          },
-          "span_id": {
-            "type": "string"
-          },
-          "timestamp": {
-            "type": "string",
-            "format": "date-time"
-          },
-          "attributes": {
-            "type": "object",
-            "additionalProperties": {
-              "oneOf": [
-                {
-                  "type": "null"
-                },
-                {
-                  "type": "boolean"
-                },
-                {
-                  "type": "number"
-                },
-                {
-                  "type": "string"
-                },
-                {
-                  "type": "array"
-                },
-                {
-                  "type": "object"
-                }
-              ]
-            }
-          },
-          "type": {
-            "type": "string",
-            "const": "metric",
-            "default": "metric"
-          },
-          "metric": {
-            "type": "string"
-          },
-          "value": {
-            "oneOf": [
-              {
-                "type": "integer"
-              },
-              {
-                "type": "number"
-              }
-            ]
-          },
-          "unit": {
-            "type": "string"
-          }
-        },
-        "additionalProperties": false,
-        "required": [
-          "trace_id",
-          "span_id",
-          "timestamp",
-          "type",
-          "metric",
-          "value",
-          "unit"
-        ]
-      },
       "SpanEndPayload": {
         "type": "object",
         "properties": {
@@ -6502,22 +6510,19 @@
           "additionalProperties": {
             "oneOf": [
               {
-                "type": "null"
+                "type": "string"
               },
               {
-                "type": "boolean"
+                "type": "integer"
               },
               {
                 "type": "number"
               },
               {
-                "type": "string"
+                "type": "boolean"
               },
               {
-                "type": "array"
+                "type": "null"
-              },
-              {
-                "type": "object"
               }
             ]
           }
@@ -6575,22 +6580,19 @@
           "additionalProperties": {
             "oneOf": [
               {
-                "type": "null"
+                "type": "string"
              },
               {
-                "type": "boolean"
+                "type": "integer"
               },
               {
                 "type": "number"
               },
               {
-                "type": "string"
+                "type": "boolean"
               },
               {
-                "type": "array"
+                "type": "null"
-              },
-              {
-                "type": "object"
               }
             ]
           }
docs/_static/llama-stack-spec.yaml (108 changes, vendored)
@@ -1925,6 +1925,10 @@ components:
     ChatCompletionResponse:
       type: object
       properties:
+        metrics:
+          type: array
+          items:
+            $ref: '#/components/schemas/MetricEvent'
         completion_message:
           $ref: '#/components/schemas/CompletionMessage'
           description: The complete response message
@@ -1938,6 +1942,46 @@ components:
       required:
         - completion_message
       description: Response from a chat completion request.
+    MetricEvent:
+      type: object
+      properties:
+        trace_id:
+          type: string
+        span_id:
+          type: string
+        timestamp:
+          type: string
+          format: date-time
+        attributes:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: string
+              - type: integer
+              - type: number
+              - type: boolean
+              - type: 'null'
+        type:
+          type: string
+          const: metric
+          default: metric
+        metric:
+          type: string
+        value:
+          oneOf:
+            - type: integer
+            - type: number
+        unit:
+          type: string
+      additionalProperties: false
+      required:
+        - trace_id
+        - span_id
+        - timestamp
+        - type
+        - metric
+        - value
+        - unit
     TokenLogProbs:
       type: object
       properties:
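For reference, a payload shaped to satisfy the MetricEvent schema added above might look like the following sketch (all identifiers and numbers are made up for illustration):

```python
# Illustrative only: an event conforming to the new MetricEvent schema.
example_metric_event = {
    "trace_id": "hypothetical-trace-id",
    "span_id": "hypothetical-span-id",
    "timestamp": "2025-02-03T12:00:00Z",
    "type": "metric",
    "metric": "prompt_tokens",
    "value": 27,
    "unit": "tokens",
    # attributes values must be scalars (string/integer/number/boolean/null)
    "attributes": {"model_id": "example-model", "cached": False},
}
```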
@@ -2173,6 +2217,10 @@ components:
     ChatCompletionResponseStreamChunk:
       type: object
       properties:
+        metrics:
+          type: array
+          items:
+            $ref: '#/components/schemas/MetricEvent'
         event:
           $ref: '#/components/schemas/ChatCompletionResponseEvent'
           description: The event containing the new content
@@ -2338,7 +2386,6 @@ components:
           Whether tool use is required or automatic. This is a hint to the model
           which may not be followed. It depends on the Instruction Following capabilities
           of the model.
-          default: auto
         tool_prompt_format:
           type: string
           enum:
@@ -4070,47 +4117,6 @@ components:
         - warn
         - error
         - critical
-    MetricEvent:
-      type: object
-      properties:
-        trace_id:
-          type: string
-        span_id:
-          type: string
-        timestamp:
-          type: string
-          format: date-time
-        attributes:
-          type: object
-          additionalProperties:
-            oneOf:
-              - type: 'null'
-              - type: boolean
-              - type: number
-              - type: string
-              - type: array
-              - type: object
-        type:
-          type: string
-          const: metric
-          default: metric
-        metric:
-          type: string
-        value:
-          oneOf:
-            - type: integer
-            - type: number
-        unit:
-          type: string
-      additionalProperties: false
-      required:
-        - trace_id
-        - span_id
-        - timestamp
-        - type
-        - metric
-        - value
-        - unit
     SpanEndPayload:
       type: object
       properties:
@@ -4153,12 +4159,11 @@ components:
         type: object
         additionalProperties:
           oneOf:
-            - type: 'null'
-            - type: boolean
-            - type: number
             - type: string
-            - type: array
-            - type: object
+            - type: integer
+            - type: number
+            - type: boolean
+            - type: 'null'
       type:
         type: string
         const: structured_log
@@ -4195,12 +4200,11 @@ components:
         type: object
         additionalProperties:
           oneOf:
-            - type: 'null'
-            - type: boolean
-            - type: number
             - type: string
-            - type: array
-            - type: object
+            - type: integer
+            - type: number
+            - type: boolean
+            - type: 'null'
       type:
         type: string
         const: unstructured_log
@@ -644,7 +644,9 @@ class Generator:
         else:
             callbacks = None

-        description = "\n".join(filter(None, [doc_string.short_description, doc_string.long_description]))
+        description = "\n".join(
+            filter(None, [doc_string.short_description, doc_string.long_description])
+        )
         return Operation(
             tags=[op.defining_class.__name__],
             summary=None,
@@ -681,6 +683,7 @@ class Generator:
             raise NotImplementedError(f"unknown HTTP method: {op.http_method}")

         route = op.get_route()
+        route = route.replace(":path", "")
         print(f"route: {route}")
         if route in paths:
             paths[route].update(pathItem)
@@ -130,6 +130,8 @@ class _FormatParameterExtractor:

 def _get_route_parameters(route: str) -> List[str]:
     extractor = _FormatParameterExtractor()
+    # Replace all occurrences of ":path" with empty string
+    route = route.replace(":path", "")
     route.format_map(extractor)
     return extractor.keys
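The `:path` converter suffix is not a valid `str.format` specifier, which is why it is stripped before the route string is probed for its parameter names. A minimal, self-contained sketch of that idea (simplified names; this is not the generator's actual module):

```python
from typing import List


class FormatParameterExtractor:
    """Collects placeholder names when used as the mapping for str.format_map."""

    def __init__(self) -> None:
        self.keys: List[str] = []

    def __getitem__(self, key: str) -> str:
        self.keys.append(key)
        return ""


def get_route_parameters(route: str) -> List[str]:
    extractor = FormatParameterExtractor()
    # "{model_id:path}" -> "{model_id}"; with the ":path" spec left in,
    # format_map would raise "Invalid format specifier".
    route = route.replace(":path", "")
    route.format_map(extractor)
    return extractor.keys


print(get_route_parameters("/models/{model_id:path}"))  # ['model_id']
```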
@@ -180,12 +180,45 @@ After this step is successful, you should be able to find the built container im
 ### Running your Stack server
 Now, let's start the Llama Stack Distribution Server. You will need the YAML configuration file which was written out at the end by the `llama stack build` step.

+```
+llama stack run -h
+usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-ipv6] [--env KEY=VALUE] [--tls-keyfile TLS_KEYFILE]
+                       [--tls-certfile TLS_CERTFILE] [--image-type {conda,container,venv}]
+                       config
+
+start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.
+
+positional arguments:
+  config                Path to config file to use for the run
+
+options:
+  -h, --help            show this help message and exit
+  --port PORT           Port to run the server on. Defaults to 8321
+  --image-name IMAGE_NAME
+                        Name of the image to run. Defaults to the current conda environment
+  --disable-ipv6        Disable IPv6 support
+  --env KEY=VALUE       Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.
+  --tls-keyfile TLS_KEYFILE
+                        Path to TLS key file for HTTPS
+  --tls-certfile TLS_CERTFILE
+                        Path to TLS certificate file for HTTPS
+  --image-type {conda,container,venv}
+                        Image Type used during the build. This can be either conda or container or venv.
+```
+
 ```
 # Start using template name
 llama stack run tgi

 # Start using config file
 llama stack run ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml
+
+# Start using a venv
+llama stack run --image-type venv ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml
+
+# Start using a conda environment
+llama stack run --image-type conda ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml
 ```

 ```
@@ -15,25 +15,25 @@ from typing import (
     Literal,
     Optional,
     Protocol,
-    runtime_checkable,
     Union,
+    runtime_checkable,
 )

 from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, ConfigDict, Field

-from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, URL
+from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
 from llama_stack.apis.inference import (
     CompletionMessage,
     ResponseFormat,
     SamplingParams,
     ToolCall,
     ToolChoice,
+    ToolConfig,
     ToolPromptFormat,
     ToolResponse,
     ToolResponseMessage,
     UserMessage,
-    ToolConfig,
 )
 from llama_stack.apis.safety import SafetyViolation
 from llama_stack.apis.tools import ToolDef
@@ -154,7 +154,7 @@ class AgentConfigCommon(BaseModel):
     output_shields: Optional[List[str]] = Field(default_factory=list)
     toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
     client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
-    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto, deprecated="use tool_config instead")
+    tool_choice: Optional[ToolChoice] = Field(default=None, deprecated="use tool_config instead")
     tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead")
     tool_config: Optional[ToolConfig] = Field(default=None)
@@ -166,11 +166,13 @@ class AgentConfigCommon(BaseModel):
             raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
         if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
             raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
-        if self.tool_config is None:
-            self.tool_config = ToolConfig(
-                tool_choice=self.tool_choice,
-                tool_prompt_format=self.tool_prompt_format,
-            )
+        else:
+            params = {}
+            if self.tool_choice:
+                params["tool_choice"] = self.tool_choice
+            if self.tool_prompt_format:
+                params["tool_prompt_format"] = self.tool_prompt_format
+            self.tool_config = ToolConfig(**params)


 @json_schema_type
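A simplified, standalone sketch of the validator pattern above (class and field names mirror the diff, but this is not the project's actual module): deprecated top-level fields are folded into `tool_config` only when they were explicitly set, so an unset deprecated field no longer overrides the nested config's own defaults.

```python
from typing import Optional

from pydantic import BaseModel, model_validator


class ToolConfig(BaseModel):
    tool_choice: str = "auto"
    tool_prompt_format: Optional[str] = None


class AgentConfigCommon(BaseModel):
    tool_choice: Optional[str] = None          # deprecated, use tool_config
    tool_prompt_format: Optional[str] = None   # deprecated, use tool_config
    tool_config: Optional[ToolConfig] = None

    @model_validator(mode="after")
    def _fold_deprecated_fields(self):
        # Only build a ToolConfig from the deprecated fields that were actually set.
        if self.tool_config is None:
            params = {}
            if self.tool_choice:
                params["tool_choice"] = self.tool_choice
            if self.tool_prompt_format:
                params["tool_prompt_format"] = self.tool_prompt_format
            self.tool_config = ToolConfig(**params)
        return self


print(AgentConfigCommon().tool_config)                        # keeps ToolConfig defaults
print(AgentConfigCommon(tool_choice="required").tool_config)  # carries the explicit choice
```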
@@ -333,7 +335,10 @@ class Agents(Protocol):
         tool_config: Optional[ToolConfig] = None,
     ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...

-    @webmethod(route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET")
+    @webmethod(
+        route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
+        method="GET",
+    )
     async def get_agents_turn(
         self,
         agent_id: str,
@@ -13,7 +13,6 @@ from termcolor import cprint
 from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
 from llama_stack.apis.common.content_types import ToolCallParseStatus
 from llama_stack.apis.inference import ToolResponseMessage
-
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
@@ -8,7 +8,6 @@ from enum import Enum
 from typing import Annotated, List, Literal, Optional, Union

 from llama_models.llama3.api.datatypes import ToolCall
-
 from llama_models.schema_utils import json_schema_type, register_schema
 from pydantic import BaseModel, Field, model_validator
@@ -8,7 +8,6 @@ from enum import Enum
 from typing import Any, Dict, Optional

 from llama_models.schema_utils import json_schema_type
-
 from pydantic import BaseModel

 from llama_stack.apis.common.content_types import URL
@@ -58,7 +58,7 @@ class Datasets(Protocol):
         metadata: Optional[Dict[str, Any]] = None,
     ) -> None: ...

-    @webmethod(route="/datasets/{dataset_id}", method="GET")
+    @webmethod(route="/datasets/{dataset_id:path}", method="GET")
     async def get_dataset(
         self,
         dataset_id: str,

@@ -67,7 +67,7 @@ class Datasets(Protocol):
     @webmethod(route="/datasets", method="GET")
     async def list_datasets(self) -> ListDatasetsResponse: ...

-    @webmethod(route="/datasets/{dataset_id}", method="DELETE")
+    @webmethod(route="/datasets/{dataset_id:path}", method="DELETE")
     async def unregister_dataset(
         self,
         dataset_id: str,
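The `{dataset_id:path}` spelling looks like the Starlette/FastAPI path converter, which lets a single parameter capture the rest of the URL, slashes included — useful when resource IDs themselves contain "/". An illustrative sketch (not the project's server code; names and routes are assumptions for the example):

```python
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
from starlette.testclient import TestClient


async def get_model(request):
    # With "{model_id:path}" the whole remainder of the URL is captured;
    # a plain "{model_id}" would stop at the first "/".
    return JSONResponse({"model_id": request.path_params["model_id"]})


app = Starlette(routes=[Route("/models/{model_id:path}", get_model)])

client = TestClient(app)
print(client.get("/models/org-name/some-model-id").json())
# -> {'model_id': 'org-name/some-model-id'}
```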
@@ -13,8 +13,8 @@ from typing import (
     Literal,
     Optional,
     Protocol,
-    runtime_checkable,
     Union,
+    runtime_checkable,
 )

 from llama_models.llama3.api.datatypes import (
@@ -31,6 +31,7 @@ from typing_extensions import Annotated

 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
 from llama_stack.apis.models import Model
+from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@@ -357,7 +358,7 @@ class ChatCompletionRequest(BaseModel):


 @json_schema_type
-class ChatCompletionResponseStreamChunk(BaseModel):
+class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
     """A chunk of a streamed chat completion response.

     :param event: The event containing the new content
@@ -367,7 +368,7 @@ class ChatCompletionResponseStreamChunk(BaseModel):


 @json_schema_type
-class ChatCompletionResponse(BaseModel):
+class ChatCompletionResponse(MetricResponseMixin, BaseModel):
     """Response from a chat completion request.

     :param completion_message: The complete response message
@@ -62,7 +62,7 @@ class Models(Protocol):
     @webmethod(route="/models", method="GET")
     async def list_models(self) -> ListModelsResponse: ...

-    @webmethod(route="/models/{model_id}", method="GET")
+    @webmethod(route="/models/{model_id:path}", method="GET")
     async def get_model(
         self,
         model_id: str,

@@ -78,7 +78,7 @@ class Models(Protocol):
         model_type: Optional[ModelType] = None,
     ) -> Model: ...

-    @webmethod(route="/models/{model_id}", method="DELETE")
+    @webmethod(route="/models/{model_id:path}", method="DELETE")
     async def unregister_model(
         self,
         model_id: str,
@@ -12,8 +12,8 @@ from typing import (
     Literal,
     Optional,
     Protocol,
-    runtime_checkable,
     Union,
+    runtime_checkable,
 )

 from llama_models.schema_utils import json_schema_type, register_schema, webmethod

@@ -134,7 +134,7 @@ class ScoringFunctions(Protocol):
     @webmethod(route="/scoring-functions", method="GET")
     async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...

-    @webmethod(route="/scoring-functions/{scoring_fn_id}", method="GET")
+    @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
     async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ...

     @webmethod(route="/scoring-functions", method="POST")
@@ -48,7 +48,7 @@ class Shields(Protocol):
     @webmethod(route="/shields", method="GET")
     async def list_shields(self) -> ListShieldsResponse: ...

-    @webmethod(route="/shields/{identifier}", method="GET")
+    @webmethod(route="/shields/{identifier:path}", method="GET")
     async def get_shield(self, identifier: str) -> Optional[Shield]: ...

     @webmethod(route="/shields", method="POST")
@@ -5,11 +5,9 @@
 # the root directory of this source tree.

 from enum import Enum
-
 from typing import Any, Dict, List, Optional, Protocol, Union

 from llama_models.schema_utils import json_schema_type, webmethod
-
 from pydantic import BaseModel

 from llama_stack.apis.inference import Message
@@ -13,10 +13,11 @@ from typing import (
     Literal,
     Optional,
     Protocol,
-    runtime_checkable,
     Union,
+    runtime_checkable,
 )

+from llama_models.llama3.api.datatypes import Primitive
 from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
@@ -76,7 +77,7 @@ class EventCommon(BaseModel):
     trace_id: str
     span_id: str
     timestamp: datetime
-    attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
+    attributes: Optional[Dict[str, Primitive]] = Field(default_factory=dict)


 @json_schema_type
@@ -94,6 +95,30 @@ class MetricEvent(EventCommon):
     unit: str


+# This is a short term solution to allow inference API to return metrics
+# The ideal way to do this is to have a way for all response types to include metrics
+# and all metric events logged to the telemetry API to be included with the response
+# To do this, we will need to augment all response types with a metrics field.
+# We have hit a blocker from stainless SDK that prevents us from doing this.
+# The blocker is that if we were to augment the response types that have a data field
+# in them like so
+# class ListModelsResponse(BaseModel):
+#     metrics: Optional[List[MetricEvent]] = None
+#     data: List[Models]
+# ...
+# The client SDK will need to access the data by using a .data field, which is not
+# ergonomic. Stainless SDK does support unwrapping the response type, but it
+# requires that the response type only have a single field.
+
+# We will need a way in the client SDK to signal that the metrics are needed
+# and if they are needed, the client SDK has to return the full response type
+# without unwrapping it.
+
+
+class MetricResponseMixin(BaseModel):
+    metrics: Optional[List[MetricEvent]] = None
+
+
 @json_schema_type
 class StructuredLogType(Enum):
     SPAN_START = "span_start"
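A minimal sketch of what the mixin buys callers (simplified, standalone types; not the project's actual module): any response model that inherits `MetricResponseMixin` carries an optional `metrics` list alongside its payload, so a client can read token counts or latencies off the same object that holds the completion.

```python
from typing import List, Optional

from pydantic import BaseModel


class MetricEvent(BaseModel):
    metric: str
    value: float
    unit: str


class MetricResponseMixin(BaseModel):
    metrics: Optional[List[MetricEvent]] = None


class ChatCompletionResponse(MetricResponseMixin, BaseModel):
    completion_message: str


resp = ChatCompletionResponse(
    completion_message="Hello!",
    metrics=[MetricEvent(metric="completion_tokens", value=12, unit="tokens")],
)
for m in resp.metrics or []:
    print(f"{m.metric}: {m.value} {m.unit}")
```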
@@ -199,13 +224,13 @@ class Telemetry(Protocol):
         order_by: Optional[List[str]] = None,
     ) -> QueryTracesResponse: ...

-    @webmethod(route="/telemetry/traces/{trace_id}", method="GET")
+    @webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
     async def get_trace(self, trace_id: str) -> Trace: ...

-    @webmethod(route="/telemetry/traces/{trace_id}/spans/{span_id}", method="GET")
+    @webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
     async def get_span(self, trace_id: str, span_id: str) -> Span: ...

-    @webmethod(route="/telemetry/spans/{span_id}/tree", method="GET")
+    @webmethod(route="/telemetry/spans/{span_id:path}/tree", method="GET")
     async def get_span_tree(
         self,
         span_id: str,
@@ -4,5 +4,5 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .tools import * # noqa: F401 F403
 from .rag_tool import * # noqa: F401 F403
+from .tools import * # noqa: F401 F403
@@ -11,7 +11,7 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmeth
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated, Protocol, runtime_checkable

-from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.apis.common.content_types import URL, InterleavedContent
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@@ -11,7 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Protocol, runtime_checkable

-from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.apis.common.content_types import URL, InterleavedContent
 from llama_stack.apis.resource import Resource, ResourceType
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol

@@ -101,7 +101,7 @@ class ToolGroups(Protocol):
         """Register a tool group"""
         ...

-    @webmethod(route="/toolgroups/{toolgroup_id}", method="GET")
+    @webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET")
     async def get_tool_group(
         self,
         toolgroup_id: str,

@@ -117,13 +117,13 @@ class ToolGroups(Protocol):
         """List tools with optional tool group"""
         ...

-    @webmethod(route="/tools/{tool_name}", method="GET")
+    @webmethod(route="/tools/{tool_name:path}", method="GET")
     async def get_tool(
         self,
         tool_name: str,
     ) -> Tool: ...

-    @webmethod(route="/toolgroups/{toolgroup_id}", method="DELETE")
+    @webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE")
     async def unregister_toolgroup(
         self,
         toolgroup_id: str,
@@ -46,7 +46,7 @@ class VectorDBs(Protocol):
     @webmethod(route="/vector-dbs", method="GET")
     async def list_vector_dbs(self) -> ListVectorDBsResponse: ...

-    @webmethod(route="/vector-dbs/{vector_db_id}", method="GET")
+    @webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET")
     async def get_vector_db(
         self,
         vector_db_id: str,

@@ -62,5 +62,5 @@ class VectorDBs(Protocol):
         provider_vector_db_id: Optional[str] = None,
     ) -> VectorDB: ...

-    @webmethod(route="/vector-dbs/{vector_db_id}", method="DELETE")
+    @webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
     async def unregister_vector_db(self, vector_db_id: str) -> None: ...
@@ -16,11 +16,9 @@ from pathlib import Path
 from typing import Dict, List, Optional

 import httpx
-
 from llama_models.datatypes import Model
 from llama_models.sku_list import LlamaDownloadInfo
 from pydantic import BaseModel, ConfigDict
-
 from rich.console import Console
 from rich.progress import (
     BarColumn,
@@ -8,7 +8,6 @@ import argparse
 import json

 from llama_models.sku_list import resolve_model
-
 from termcolor import colored

 from llama_stack.cli.subcommand import Subcommand
@@ -38,7 +38,7 @@ class ModelList(Subcommand):

         headers = [
             "Model Descriptor",
-            "Hugging Face Repo",
+            "Model ID",
             "Context Length",
         ]
@@ -11,7 +11,6 @@ from llama_stack.cli.model.download import ModelDownload
 from llama_stack.cli.model.list import ModelList
 from llama_stack.cli.model.prompt_format import ModelPromptFormat
 from llama_stack.cli.model.verify_download import ModelVerifyDownload
-
 from llama_stack.cli.subcommand import Subcommand

@@ -26,6 +25,8 @@ class ModelParser(Subcommand):
             description="Work with llama models",
         )

+        self.parser.set_defaults(func=lambda args: self.parser.print_help())
+
         subparsers = self.parser.add_subparsers(title="model_subcommands")

         # Add sub-commands
@@ -8,7 +8,7 @@ import argparse
 import textwrap
 from io import StringIO

-from llama_models.datatypes import CoreModelId, is_multimodal, model_family, ModelFamily
+from llama_models.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family

 from llama_stack.cli.subcommand import Subcommand
@@ -9,7 +9,6 @@ from typing import Any, Dict, Optional
 from llama_models.datatypes import CheckpointQuantizationFormat
 from llama_models.llama3.api.datatypes import SamplingParams
 from llama_models.sku_list import LlamaDownloadInfo
-
 from pydantic import BaseModel, ConfigDict, Field
@@ -21,12 +21,11 @@ from prompt_toolkit.validation import Validator
 from termcolor import cprint

 from llama_stack.cli.table import print_table
-
 from llama_stack.distribution.build import (
+    SERVER_DEPENDENCIES,
+    ImageType,
     build_image,
     get_provider_dependencies,
-    ImageType,
-    SERVER_DEPENDENCIES,
 )
 from llama_stack.distribution.datatypes import (
     BuildConfig,
@@ -21,15 +21,19 @@ class StackListProviders(Subcommand):
         self._add_arguments()
         self.parser.set_defaults(func=self._run_providers_list_cmd)

-    def _add_arguments(self):
+    @property
+    def providable_apis(self):
         from llama_stack.distribution.distribution import providable_apis

-        api_values = [api.value for api in providable_apis()]
+        return [api.value for api in providable_apis()]

+    def _add_arguments(self):
         self.parser.add_argument(
             "api",
             type=str,
-            choices=api_values,
-            help="API to list providers for (one of: {})".format(api_values),
+            choices=self.providable_apis,
+            nargs="?",
+            help="API to list providers for. List all if not specified.",
         )

     def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:

@@ -37,20 +41,29 @@ class StackListProviders(Subcommand):
         from llama_stack.distribution.distribution import Api, get_provider_registry

         all_providers = get_provider_registry()
-        providers_for_api = all_providers[Api(args.api)]
+        if args.api:
+            providers = [(args.api, all_providers[Api(args.api)])]
+        else:
+            providers = [(k.value, prov) for k, prov in all_providers.items()]
+
+        providers = [p for api, p in providers if api in self.providable_apis]

         # eventually, this should query a registry at llama.meta.com/llamastack/distributions
         headers = [
+            "API Type",
             "Provider Type",
             "PIP Package Dependencies",
         ]

         rows = []
-        for spec in providers_for_api.values():
-            if spec.provider_type == "sample":
+        specs = [spec for p in providers for spec in p.values()]
+        for spec in specs:
+            if spec.is_sample:
                 continue
             rows.append(
                 [
+                    spec.api.value,
                     spec.provider_type,
                     ",".join(spec.pip_packages),
                 ]

@@ -59,4 +72,5 @@ class StackListProviders(Subcommand):
             rows,
             headers,
             separate_rows=True,
+            sort_by=(0, 1),
         )
@@ -65,6 +65,13 @@ class StackRun(Subcommand):
             type=str,
             help="Path to TLS certificate file for HTTPS",
         )
+        self.parser.add_argument(
+            "--image-type",
+            type=str,
+            help="Image Type used during the build. This can be either conda or container or venv.",
+            choices=["conda", "container", "venv"],
+            default="conda",
+        )
 
     def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
         import importlib.resources
@@ -118,11 +125,11 @@ class StackRun(Subcommand):
         config_dict = yaml.safe_load(config_file.read_text())
         config = parse_and_maybe_upgrade_config(config_dict)
 
-        if config.container_image:
+        if args.image_type == ImageType.container.value or config.container_image:
             script = importlib.resources.files("llama_stack") / "distribution/start_container.sh"
             image_name = f"distribution-{template_name}" if template_name else config.container_image
             run_args = [script, image_name]
-        else:
+        elif args.image_type == ImageType.conda.value:
             current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
             image_name = args.image_name or current_conda_env
             if not image_name:
@@ -167,6 +174,15 @@ class StackRun(Subcommand):
                 script,
                 image_name,
             ]
+        else:
+            # else must be venv since that is the only valid option left.
+            current_venv = os.environ.get("VIRTUAL_ENV")
+            venv = args.image_name or current_venv
+            script = importlib.resources.files("llama_stack") / "distribution/start_venv.sh"
+            run_args = [
+                script,
+                venv,
+            ]
 
         run_args.extend([str(config_file), str(args.port)])
         if args.disable_ipv6:
@@ -31,6 +31,8 @@ class StackParser(Subcommand):
             version=f"{version('llama-stack')}",
         )
 
+        self.parser.set_defaults(func=lambda args: self.parser.print_help())
+
         subparsers = self.parser.add_subparsers(title="stack_subcommands")
 
         # Add sub-commands
@@ -6,6 +6,7 @@
 
 import re
 import textwrap
+from typing import Iterable
 
 from termcolor import cprint
 
@@ -39,11 +40,15 @@ def format_row(row, col_widths):
     return "\n".join(lines)
 
 
-def print_table(rows, headers=None, separate_rows: bool = False):
+def print_table(rows, headers=None, separate_rows: bool = False, sort_by: Iterable[int] = tuple()):
     def itemlen(item):
         return max([len(line) for line in strip_ansi_colors(item).split("\n")])
 
     rows = [[x or "" for x in row] for row in rows]
+
+    if sort_by:
+        rows.sort(key=lambda x: tuple(x[i] for i in sort_by))
+
     if not headers:
         col_widths = [max(itemlen(item) for item in col) for col in zip(*rows)]
     else:
@@ -8,6 +8,7 @@ from datetime import datetime
 
 import pytest
 import yaml
+
 from llama_stack.distribution.configure import (
     LLAMA_STACK_RUN_CONFIG_VERSION,
     parse_and_maybe_upgrade_config,
@@ -8,7 +8,6 @@ import importlib.resources
 import logging
 import sys
 from enum import Enum
-
 from pathlib import Path
 from typing import Dict, List
 
@@ -16,11 +15,8 @@ from pydantic import BaseModel
 from termcolor import cprint
 
 from llama_stack.distribution.datatypes import BuildConfig, Provider
-
 from llama_stack.distribution.distribution import get_provider_registry
-
 from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
-
 from llama_stack.distribution.utils.exec import run_command, run_with_pty
 from llama_stack.providers.datatypes import Api
 
@@ -5,18 +5,16 @@
 # the root directory of this source tree.
 
 import inspect
-
 import json
 from collections.abc import AsyncIterator
 from enum import Enum
-from typing import Any, get_args, get_origin, Type, Union
+from typing import Any, Type, Union, get_args, get_origin
 
 import httpx
 from pydantic import BaseModel, parse_obj_as
 from termcolor import cprint
 
 from llama_stack.apis.version import LLAMA_STACK_API_VERSION
-
 from llama_stack.providers.datatypes import RemoteProviderConfig
 
 _CLIENT_CLASSES = {}
@@ -5,12 +5,11 @@
 # the root directory of this source tree.
 import logging
 import textwrap
-
 from typing import Any, Dict
 
 from llama_stack.distribution.datatypes import (
-    DistributionSpec,
     LLAMA_STACK_RUN_CONFIG_VERSION,
+    DistributionSpec,
     Provider,
     StackRunConfig,
 )
@@ -20,7 +19,6 @@ from llama_stack.distribution.distribution import (
 )
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
 from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
-
 from llama_stack.providers.datatypes import Api, ProviderSpec
 
 logger = logging.getLogger(__name__)
@@ -82,3 +82,6 @@ class DistributionInspectImpl(Inspect):
 
     async def version(self) -> VersionInfo:
         return VersionInfo(version=version("llama-stack"))
+
+    async def shutdown(self) -> None:
+        pass
@@ -13,10 +13,21 @@ import re
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path
-from typing import Any, get_args, get_origin, Optional, TypeVar
+from typing import Any, Optional, TypeVar, get_args, get_origin
 
 import httpx
 import yaml
+from llama_stack_client import (
+    NOT_GIVEN,
+    APIResponse,
+    AsyncAPIResponse,
+    AsyncLlamaStackClient,
+    AsyncStream,
+    LlamaStackClient,
+)
+from pydantic import BaseModel, TypeAdapter
+from rich.console import Console
+from termcolor import cprint
 
 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
@@ -35,17 +46,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
     setup_logger,
     start_trace,
 )
-from llama_stack_client import (
-    APIResponse,
-    AsyncAPIResponse,
-    AsyncLlamaStackClient,
-    AsyncStream,
-    LlamaStackClient,
-    NOT_GIVEN,
-)
-from pydantic import BaseModel, TypeAdapter
-from rich.console import Console
-from termcolor import cprint
 
 T = TypeVar("T")
 
@@ -7,7 +7,6 @@
 from typing import Any, Dict
 
 from llama_stack.distribution.datatypes import RoutedProtocol
-
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.providers.datatypes import Api, RoutingTable
 
@@ -6,7 +6,7 @@
 
 from typing import Any, AsyncGenerator, Dict, List, Optional
 
-from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.apis.common.content_types import URL, InterleavedContent
 from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
 from llama_stack.apis.eval import (
     AppEvalTaskConfig,
@@ -537,3 +537,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
         for tool in tools:
             await self.unregister_object(tool)
         await self.unregister_object(tool_group)
+
+    async def shutdown(self) -> None:
+        pass
@@ -10,11 +10,8 @@ from typing import Dict, List
 from pydantic import BaseModel
 
 from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
-
 from llama_stack.apis.version import LLAMA_STACK_API_VERSION
-
 from llama_stack.distribution.resolver import api_protocol_map
-
 from llama_stack.providers.datatypes import Api
 
@@ -9,6 +9,7 @@ import asyncio
 import functools
 import inspect
 import json
+import logging
 import os
 import signal
 import sys
@@ -20,7 +21,8 @@ from pathlib import Path
 from typing import Any, List, Union
 
 import yaml
-from fastapi import Body, FastAPI, HTTPException, Path as FastapiPath, Request
+from fastapi import Body, FastAPI, HTTPException, Request
+from fastapi import Path as FastapiPath
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, StreamingResponse
 from pydantic import BaseModel, ValidationError
@@ -52,6 +54,9 @@ from .endpoints import get_all_api_endpoints
 
 REPO_ROOT = Path(__file__).parent.parent.parent.parent
 
+logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
+logger = logging.getLogger(__name__)
+
 
 def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
     log = file if hasattr(file, "write") else sys.stderr
@@ -112,22 +117,70 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]:
     )
 
 
-def handle_sigint(app, *args, **kwargs):
-    print("SIGINT or CTRL-C detected. Exiting gracefully...")
-
-    async def run_shutdown():
-        for impl in app.__llama_stack_impls__.values():
-            print(f"Shutting down {impl}")
-            await impl.shutdown()
-
-    asyncio.run(run_shutdown())
-
-    loop = asyncio.get_event_loop()
-    for task in asyncio.all_tasks(loop):
-        task.cancel()
-
-    loop.stop()
+def handle_signal(app, signum, _) -> None:
+    """
+    Handle incoming signals and initiate a graceful shutdown of the application.
+
+    This function is intended to be used as a signal handler for various signals
+    (e.g., SIGINT, SIGTERM). Upon receiving a signal, it will print a message
+    indicating the received signal and initiate a shutdown process.
+
+    Args:
+        app: The application instance containing implementations to be shut down.
+        signum (int): The signal number received.
+        frame: The current stack frame (not used in this function).
+
+    The shutdown process involves:
+        - Shutting down all implementations registered in the application.
+        - Gathering all running asyncio tasks.
+        - Cancelling all gathered tasks.
+        - Waiting for all tasks to finish.
+        - Stopping the event loop.
+
+    Note:
+        This function schedules the shutdown process as an asyncio task and does
+        not block the current execution.
+    """
+    signame = signal.Signals(signum).name
+    print(f"Received signal {signame} ({signum}). Exiting gracefully...")
+
+    async def shutdown():
+        try:
+            # Gracefully shut down implementations
+            for impl in app.__llama_stack_impls__.values():
+                impl_name = impl.__class__.__name__
+                logger.info("Shutting down %s", impl_name)
+                try:
+                    if hasattr(impl, "shutdown"):
+                        await asyncio.wait_for(impl.shutdown(), timeout=5)
+                    else:
+                        logger.warning("No shutdown method for %s", impl_name)
+                except asyncio.TimeoutError:
+                    logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
+                except Exception as e:
+                    logger.exception("Failed to shutdown %s: %s", impl_name, {e})
+
+            # Gather all running tasks
+            loop = asyncio.get_running_loop()
+            tasks = [task for task in asyncio.all_tasks(loop) if task is not asyncio.current_task()]
+
+            # Cancel all tasks
+            for task in tasks:
+                task.cancel()
+
+            # Wait for all tasks to finish
+            try:
+                await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
+            except asyncio.TimeoutError:
+                logger.exception("Timeout while waiting for tasks to finish")
+            except asyncio.CancelledError:
+                pass
+        finally:
+            loop.stop()
+
+    loop = asyncio.get_running_loop()
+    loop.create_task(shutdown())
 
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
@@ -386,7 +439,8 @@ def main():
     print("")
     app.exception_handler(RequestValidationError)(global_exception_handler)
     app.exception_handler(Exception)(global_exception_handler)
-    signal.signal(signal.SIGINT, functools.partial(handle_sigint, app))
+    signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
+    signal.signal(signal.SIGTERM, functools.partial(handle_signal, app))
 
     app.__llama_stack_impls__ = impls
llama_stack/distribution/start_venv.sh (new executable file, 71 lines)
@@ -0,0 +1,71 @@
+#!/bin/bash
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+set -euo pipefail
+
+RED='\033[0;31m'
+NC='\033[0m' # No Color
+
+error_handler() {
+    echo "Error occurred in script at line: ${1}" >&2
+    exit 1
+}
+
+trap 'error_handler ${LINENO}' ERR
+
+if [ $# -lt 3 ]; then
+    echo "Usage: $0 <venv_path> <yaml_config> <port> <script_args...>"
+    exit 1
+fi
+
+venv_path="$1"
+shift
+
+yaml_config="$1"
+shift
+
+port="$1"
+shift
+
+# Initialize env_vars as an empty array
+env_vars=""
+other_args=""
+# Process environment variables from --env arguments
+while [[ $# -gt 0 ]]; do
+    case "$1" in
+    --env)
+
+        if [[ -n "$2" ]]; then
+            env_vars="$env_vars --env $2"
+            shift 2
+        else
+            echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
+            exit 1
+        fi
+        ;;
+    *)
+        other_args="$other_args $1"
+        shift
+        ;;
+    esac
+done
+
+# Activate virtual environment
+if [ ! -d "$venv_path" ]; then
+    echo -e "${RED}Error: Virtual environment not found at $venv_path${NC}" >&2
+    exit 1
+fi
+
+source "$venv_path/bin/activate"
+
+set -x
+python -m llama_stack.distribution.server.server \
+    --yaml-config "$yaml_config" \
+    --port "$port" \
+    $env_vars \
+    $other_args
@@ -8,9 +8,9 @@ import os
 
 import pytest
 import pytest_asyncio
+
 from llama_stack.apis.inference import Model
 from llama_stack.apis.vector_dbs import VectorDB
-
 from llama_stack.distribution.store.registry import (
     CachedDiskDistributionRegistry,
     DiskDistributionRegistry,
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import os
-
 from typing import Optional
 
 from llama_stack_client import LlamaStackClient
@@ -10,7 +10,6 @@ from page.distribution.models import models
 from page.distribution.scoring_functions import scoring_functions
 from page.distribution.shields import shields
 from page.distribution.vector_dbs import vector_dbs
-
 from streamlit_option_menu import option_menu
 
 
@@ -8,7 +8,6 @@ import json
 
 import pandas as pd
 import streamlit as st
-
 from modules.api import llama_stack_api
 from modules.utils import process_dataset
 
@@ -7,9 +7,7 @@
 import json
 
 import pandas as pd
-
 import streamlit as st
-
 from modules.api import llama_stack_api
 
 
@@ -9,7 +9,6 @@ from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client.lib.agents.event_logger import EventLogger
 from llama_stack_client.types.agent_create_params import AgentConfig
 from llama_stack_client.types.memory_insert_params import Document
-
 from modules.api import llama_stack_api
 from modules.utils import data_url_from_file
 
@@ -7,7 +7,6 @@
 import os
 from pathlib import Path
 
-
 LLAMA_STACK_CONFIG_DIR = Path(os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/")))
 
 DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
@@ -8,13 +8,11 @@ import inspect
 import json
 import logging
 from enum import Enum
-from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
+from typing import Any, List, Literal, Optional, Type, Union, get_args, get_origin
 
-
 from pydantic import BaseModel
 from pydantic.fields import FieldInfo
 from pydantic_core import PydanticUndefinedType
-
 from typing_extensions import Annotated
 
 log = logging.getLogger(__name__)
@@ -11,7 +11,6 @@ from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
 
 from llama_stack.apis.datasets import Dataset
-
 from llama_stack.apis.datatypes import Api
 from llama_stack.apis.eval_tasks import EvalTask
 from llama_stack.apis.models import Model
@@ -86,6 +85,10 @@ class ProviderSpec(BaseModel):
     # used internally by the resolver; this is a hack for now
     deps__: List[str] = Field(default_factory=list)
 
+    @property
+    def is_sample(self) -> bool:
+        return self.provider_type in ("sample", "remote::sample")
+
 
 class RoutingTable(Protocol):
     def get_provider_impl(self, routing_key: str) -> Any: ...
@@ -42,10 +42,10 @@ from llama_stack.apis.agents import (
     Turn,
 )
 from llama_stack.apis.common.content_types import (
+    URL,
     TextContentItem,
     ToolCallDelta,
     ToolCallParseStatus,
-    URL,
 )
 from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
@@ -513,6 +513,9 @@ class ChatAgent(ShieldRunnerMixin):
                 if delta.type == "tool_call":
                     if delta.parse_status == ToolCallParseStatus.succeeded:
                         tool_calls.append(delta.tool_call)
+                    elif delta.parse_status == ToolCallParseStatus.failed:
+                        # If we cannot parse the tools, set the content to the unparsed raw text
+                        content = delta.tool_call
                     if stream:
                         yield AgentTurnResponseStreamChunk(
                             event=AgentTurnResponseEvent(
@@ -81,12 +81,6 @@ class MetaReferenceAgentsImpl(Agents):
     ) -> AgentCreateResponse:
         agent_id = str(uuid.uuid4())
 
-        if agent_config.tool_config is None:
-            agent_config.tool_config = ToolConfig(
-                tool_choice=agent_config.tool_choice,
-                tool_prompt_format=agent_config.tool_prompt_format,
-            )
-
         await self.persistence_store.set(
             key=f"agent:{agent_id}",
             value=agent_config.model_dump_json(),
@@ -218,3 +212,6 @@ class MetaReferenceAgentsImpl(Agents):
 
     async def delete_agent(self, agent_id: str) -> None:
         await self.persistence_store.delete(f"agent:{agent_id}")
+
+    async def shutdown(self) -> None:
+        pass
@@ -6,11 +6,9 @@
 
 import asyncio
 import logging
-
 from typing import List
 
 from llama_stack.apis.inference import Message
-
 from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
 
 log = logging.getLogger(__name__)
@@ -41,7 +41,6 @@ from llama_stack.apis.tools import (
     ToolInvocationResult,
 )
 from llama_stack.apis.vector_io import QueryChunksResponse
-
 from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
     MEMORY_QUERY_TOOL,
 )
@@ -15,14 +15,12 @@ import pandas
 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
 from llama_stack.apis.datasets import Dataset
-
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
 from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
 from llama_stack.providers.utils.kvstore import kvstore_impl
 
 from .config import LocalFSDatasetIOConfig
 
-
 DATASETS_PREFIX = "localfs_datasets:"
 
@@ -15,7 +15,6 @@ from llama_stack.apis.inference import Inference, UserMessage
 from llama_stack.apis.scoring import Scoring
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
-
 from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
     MEMORY_QUERY_TOOL,
 )
@@ -28,7 +27,6 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
 
 from .....apis.common.job_types import Job
 from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
-
 from .config import MetaReferenceEvalConfig
 
 EVAL_TASKS_PREFIX = "eval_tasks:"
@@ -9,7 +9,6 @@ from typing import Any, Dict, Optional
 from pydantic import BaseModel, field_validator
 
 from llama_stack.apis.inference import QuantizationConfig
-
 from llama_stack.providers.utils.inference import supported_inference_models
 
@@ -37,7 +37,6 @@ from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
 from llama_models.sku_list import resolve_model
-
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
 from pydantic import BaseModel
 
@@ -47,7 +46,6 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     ResponseFormatType,
 )
-
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
@@ -46,8 +46,8 @@ from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
 from llama_stack.providers.utils.inference.model_registry import (
-    build_model_alias,
     ModelRegistryHelper,
+    build_model_alias,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     augment_content_with_response_format_prompt,
@@ -22,16 +22,13 @@ from typing import Callable, Generator, Literal, Optional, Union
 
 import torch
 import zmq
-
 from fairscale.nn.model_parallel.initialize import (
     get_model_parallel_group,
     get_model_parallel_rank,
     get_model_parallel_src_rank,
 )
-
 from pydantic import BaseModel, Field
-from torch.distributed.launcher.api import elastic_launch, LaunchConfig
-
+from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 from typing_extensions import Annotated
 
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -8,7 +8,6 @@
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 
 import collections
-
 import logging
 from typing import Optional, Type
 
@@ -23,7 +22,7 @@ except ImportError:
     raise
 
 import torch
-from torch import nn, Tensor
+from torch import Tensor, nn
 
 
 class Fp8ScaledWeights:
@@ -10,9 +10,9 @@
 import unittest
 
 import torch
-from fp8_impls import ffn_swiglu_fp8_dynamic, FfnQuantizeMode, quantize_fp8
-from hypothesis import given, settings, strategies as st
+from fp8_impls import FfnQuantizeMode, ffn_swiglu_fp8_dynamic, quantize_fp8
+from hypothesis import given, settings
+from hypothesis import strategies as st
 from torch import Tensor
 
@@ -12,18 +12,13 @@ import os
 from typing import Any, Dict, List, Optional
 
 import torch
-
 from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
-
 from llama_models.datatypes import CheckpointQuantizationFormat
-
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
 from llama_models.sku_list import resolve_model
-from torch import nn, Tensor
-
+from torch import Tensor, nn
 from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
 from llama_stack.apis.inference import QuantizationType
@@ -16,14 +16,12 @@ from pathlib import Path
 from typing import Optional
 
 import fire
-
 import torch
 from fairscale.nn.model_parallel.initialize import (
     get_model_parallel_rank,
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
-
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
@@ -15,9 +15,9 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
-    ToolConfig,
 )
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
@@ -37,9 +37,9 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.models import Model
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
-    get_sampling_options,
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
+    get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
@@ -201,7 +201,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
-        return process_chat_completion_response(response, self.formatter)
+        return process_chat_completion_response(response, self.formatter, request)
 
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest, results_generator: AsyncGenerator
@@ -227,7 +227,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         )
 
         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk
 
     async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
@@ -15,10 +15,8 @@ from typing import Any, Callable, Dict
 import torch
 from llama_models.datatypes import Model
 from llama_models.sku_list import resolve_model
-
 from pydantic import BaseModel
 from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
-
 from torchtune.models.llama3 import llama3_tokenizer
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.models.llama3_1 import lora_llama3_1_8b
@@ -13,7 +13,6 @@
 from typing import Any, Dict, List, Mapping
 
 import numpy as np
-
 from torch.utils.data import Dataset
 from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
 from torchtune.data._messages import validate_messages
@@ -18,9 +18,9 @@ from llama_models.sku_list import resolve_model
 from torch import nn
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
-from torchtune import modules, training, utils as torchtune_utils
+from torchtune import modules, training
+from torchtune import utils as torchtune_utils
 from torchtune.data import padded_collate_sft
-
 from torchtune.modules.loss import CEWithChunkedOutputLoss
 from torchtune.modules.peft import (
     get_adapter_params,
@@ -44,14 +44,11 @@ from llama_stack.apis.post_training import (
     OptimizerConfig,
     TrainingConfig,
 )
-
 from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
-
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.inline.post_training.common.validator import (
     validate_input_dataset_schema,
 )
-
 from llama_stack.providers.inline.post_training.torchtune.common import utils
 from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
     TorchtuneCheckpointer,
@@ -21,7 +21,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import CodeScannerConfig
 
-
 log = logging.getLogger(__name__)
 
 ALLOWED_CODE_SCANNER_MODEL_IDS = [
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import re
-
 from string import Template
 from typing import Any, Dict, List, Optional
 
@@ -25,10 +24,8 @@ from llama_stack.apis.safety import (
     SafetyViolation,
     ViolationLevel,
 )
-
 from llama_stack.apis.shields import Shield
 from llama_stack.distribution.datatypes import Api
-
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
@@ -36,7 +33,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import LlamaGuardConfig
 
-
 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
 
 SAFE_RESPONSE = "safe"
@@ -8,7 +8,6 @@ import logging
 from typing import Any, Dict, List
 
 import torch
-
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
 from llama_stack.apis.inference import Message
@@ -19,7 +18,6 @@ from llama_stack.apis.safety import (
     ViolationLevel,
 )
 from llama_stack.apis.shields import Shield
-
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -14,13 +14,13 @@ from llama_stack.apis.scoring import (
     ScoringResult,
 )
 from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
-
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 from llama_stack.providers.utils.common.data_schema_validator import (
     get_valid_schemas,
     validate_dataset_schema,
 )
+
 from .config import BasicScoringConfig
 from .scoring_fn.equality_scoring_fn import EqualityScoringFn
 from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
@@ -7,7 +7,6 @@
 from typing import Any, Dict, Optional
 
 from llama_stack.apis.scoring import ScoringResultRow
-
 from llama_stack.apis.scoring_functions import ScoringFnParams
 from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
 
@@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import (
     ScoringFn,
 )
 
-
 equality = ScoringFn(
     identifier="basic::equality",
     description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
@@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import (
     ScoringFn,
 )
 
-
 subset_of = ScoringFn(
     identifier="basic::subset_of",
     description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import re
-
 from typing import Any, Dict, Optional
 
 from llama_stack.apis.scoring import ScoringResultRow
@@ -29,9 +29,7 @@ from llama_stack.apis.scoring import (
     ScoringResultRow,
 )
 from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
-
 from llama_stack.distribution.datatypes import Api
-
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 from llama_stack.providers.utils.common.data_schema_validator import (
@@ -39,8 +37,8 @@ from llama_stack.providers.utils.common.data_schema_validator import (
     validate_dataset_schema,
     validate_row_schema,
 )
-
 from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
+
 from .config import BraintrustScoringConfig
 from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
 from .scoring_fn.fn_defs.answer_relevancy import answer_relevancy_fn_def
@@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import (
     ScoringFn,
 )
 
-
 answer_correctness_fn_def = ScoringFn(
     identifier="braintrust::answer-correctness",
     description=(
@@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import (
     ScoringFn,
 )
 
-
 factuality_fn_def = ScoringFn(
     identifier="braintrust::factuality",
     description=(
@@ -8,7 +8,6 @@ from typing import Any, Dict, List, Optional
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.inference.inference import Inference
-
 from llama_stack.apis.scoring import (
     ScoreBatchResponse,
     ScoreResponse,
@@ -26,7 +25,6 @@ from llama_stack.providers.utils.common.data_schema_validator import (
 from .config import LlmAsJudgeScoringConfig
 from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
 
-
 LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
 
@@ -7,7 +7,6 @@
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams, ScoringFn

-
llm_as_judge_base = ScoringFn(
    identifier="llm-as-judge::base",
    description="Llm As Judge Scoring Function",
@@ -4,18 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
-
from typing import Any, Dict, Optional

from llama_stack.apis.inference.inference import Inference
-
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
-
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa
-
from .fn_defs.llm_as_judge_base import llm_as_judge_base


@@ -5,6 +5,7 @@
# the root directory of this source tree.

from llama_stack.apis.telemetry import Telemetry
+
from .config import SampleConfig


@@ -82,7 +82,11 @@ import sys as _sys
# them with linters - they're used in code_execution.py
from contextlib import ( # noqa
    contextmanager as _contextmanager,
+)
+from contextlib import (
    redirect_stderr as _redirect_stderr,
+)
+from contextlib import (
    redirect_stdout as _redirect_stdout,
)
from multiprocessing.connection import Connection as _Connection
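The hunk above splits one grouped `contextlib` import into three single-name imports; the bound names do not change. A minimal sketch (illustration only, not code from this repository) showing that the split form binds the same aliases:

```python
# Illustration only: the split imports bind the same aliases as the original
# grouped "from contextlib import (...)" statement.
from contextlib import contextmanager as _contextmanager
from contextlib import redirect_stderr as _redirect_stderr
from contextlib import redirect_stdout as _redirect_stdout

# Each alias is the usual contextlib helper.
assert callable(_contextmanager)
assert callable(_redirect_stderr)
assert callable(_redirect_stdout)
```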
@@ -9,7 +9,6 @@ from jinja2 import Template

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import UserMessage
-
from llama_stack.apis.tools.rag_tool import (
    DefaultRAGQueryGeneratorConfig,
    LLMRAGQueryGeneratorConfig,
@@ -11,9 +11,9 @@ import string
from typing import Any, Dict, List, Optional

from llama_stack.apis.common.content_types import (
+    URL,
    InterleavedContent,
    TextContentItem,
-    URL,
)
from llama_stack.apis.inference import Inference
from llama_stack.apis.tools import (
@@ -8,10 +8,10 @@ from typing import Dict

from llama_stack.providers.datatypes import Api, ProviderSpec

-from .config import ChromaInlineImplConfig
+from .config import ChromaVectorIOConfig


-async def get_provider_impl(config: ChromaInlineImplConfig, deps: Dict[Api, ProviderSpec]):
+async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, ProviderSpec]):
    from llama_stack.providers.remote.vector_io.chroma.chroma import (
        ChromaVectorIOAdapter,
    )
@@ -9,7 +9,7 @@ from typing import Any, Dict
from pydantic import BaseModel


-class ChromaInlineImplConfig(BaseModel):
+class ChromaVectorIOConfig(BaseModel):
    db_path: str

    @classmethod
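For orientation, a hedged sketch of constructing the renamed config. Only the `ChromaVectorIOConfig` name and its `db_path` field appear in the hunks above; the import path and the path value are assumptions.

```python
# Sketch under assumptions: the module path mirrors the inline Chroma provider
# layout implied by this diff, and "/tmp/chroma.db" is a placeholder value.
from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOConfig

config = ChromaVectorIOConfig(db_path="/tmp/chroma.db")
print(config.db_path)
```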
@@ -7,14 +7,15 @@
from typing import Dict

from llama_stack.providers.datatypes import Api, ProviderSpec
-from .config import FaissImplConfig
+
+from .config import FaissVectorIOConfig


-async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]):
+async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, ProviderSpec]):
-    from .faiss import FaissVectorIOImpl
+    from .faiss import FaissVectorIOAdapter

-    assert isinstance(config, FaissImplConfig), f"Unexpected config type: {type(config)}"
+    assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"

-    impl = FaissVectorIOImpl(config, deps[Api.inference])
+    impl = FaissVectorIOAdapter(config, deps[Api.inference])
    await impl.initialize()
    return impl
@@ -16,7 +16,7 @@ from llama_stack.providers.utils.kvstore.config import (


@json_schema_type
-class FaissImplConfig(BaseModel):
+class FaissVectorIOConfig(BaseModel):
    kvstore: KVStoreConfig

    @classmethod
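A similar hedged sketch for the renamed Faiss config. The `kvstore` field comes from the hunk above; using `SqliteKVStoreConfig` (imported from the kvstore utilities referenced later in this diff) and its `db_path` argument is an assumption.

```python
# Sketch under assumptions: FaissVectorIOConfig wraps a kvstore config; the
# SQLite-backed kvstore and its db_path argument are assumed for illustration.
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

config = FaissVectorIOConfig(kvstore=SqliteKVStoreConfig(db_path="/tmp/faiss_registry.db"))
print(type(config.kvstore).__name__)
```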
@@ -8,11 +8,9 @@ import base64
import io
import json
import logging
-
from typing import Any, Dict, List, Optional

import faiss
-
import numpy as np
from numpy.typing import NDArray

@@ -26,7 +24,7 @@ from llama_stack.providers.utils.memory.vector_store import (
    VectorDBWithIndex,
)

-from .config import FaissImplConfig
+from .config import FaissVectorIOConfig

logger = logging.getLogger(__name__)

@@ -114,8 +112,8 @@ class FaissIndex(EmbeddingIndex):
        return QueryChunksResponse(chunks=chunks, scores=scores)


-class FaissVectorIOImpl(VectorIO, VectorDBsProtocolPrivate):
+class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
-    def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
+    def __init__(self, config: FaissVectorIOConfig, inference_api: Api.inference) -> None:
        self.config = config
        self.inference_api = inference_api
        self.cache = {}
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
+from .config import SQLiteVectorIOConfig
+
+
+async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, ProviderSpec]):
+    from .sqlite_vec import SQLiteVecVectorIOAdapter
+
+    assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
+    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference])
+    await impl.initialize()
+    return impl
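The new provider follows the same factory contract as the Faiss and Chroma providers above: `get_provider_impl` checks the config type, constructs the adapter with the inference dependency from `deps[Api.inference]`, and awaits `initialize()` before returning it. A hedged sketch that only inspects the factory's signature (no provider is instantiated; it assumes a llama-stack install that includes this module):

```python
# Sketch under assumptions: importing the package only pulls in the config, so
# this inspects the factory signature without constructing any provider.
import inspect

from llama_stack.providers.inline.vector_io.sqlite_vec import get_provider_impl

print(inspect.signature(get_provider_impl))
# (config: SQLiteVectorIOConfig, deps: Dict[Api, ProviderSpec])
```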
llama_stack/providers/inline/vector_io/sqlite_vec/config.py (Normal file, 29 lines)
@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# config.py
+from typing import Any, Dict
+
+from pydantic import BaseModel
+
+from llama_stack.providers.utils.kvstore.config import (
+    KVStoreConfig,
+    SqliteKVStoreConfig,
+)
+
+
+class SQLiteVectorIOConfig(BaseModel):
+    db_path: str
+    kvstore: KVStoreConfig
+
+    @classmethod
+    def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
+        return {
+            "kvstore": SqliteKVStoreConfig.sample_run_config(
+                __distro_dir__=__distro_dir__,
+                db_name="sqlite_vec.db",
+            )
+        }
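Note that `sample_run_config` above only fills in the `kvstore` entry; `db_path` is not part of the sample it returns. A hedged usage sketch (the distro directory value is a placeholder):

```python
# Sketch under assumptions: prints the sample kvstore entry; the distro
# directory string is a placeholder value.
from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig

sample = SQLiteVectorIOConfig.sample_run_config(__distro_dir__="~/.llama/distributions/example")
print(sample["kvstore"])
```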
Some files were not shown because too many files have changed in this diff.