Merge branch 'main' into jwm4-add-qdrant-to-provider-tests

commit 89411c839a
Bill Murdock, 2025-02-13 14:08:46 -05:00 (committed by GitHub)
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
198 changed files with 1432 additions and 646 deletions

View file

@ -8,4 +8,3 @@
[Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*]
[//]: # (## Documentation)
[//]: # (- [ ] Added a Changelog entry if the change is significant)

View file

@ -29,10 +29,12 @@ repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.4
hooks:
# Run the linter with import sorting.
- id: ruff
args: [
--fix,
--exit-non-zero-on-fix
--exit-non-zero-on-fix,
--select, I,
]
- id: ruff-format
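For context, the `--select I` flag added here enables Ruff's isort-compatible import-sorting rules alongside `--fix`, which is what produces the many mechanical import reorderings later in this commit. A rough before/after sketch (the module names are only illustrative):
```python
# Before: stdlib and third-party imports interleaved, from-import members unsorted
from typing import Optional, Protocol, runtime_checkable, Union
from pydantic import BaseModel
import os

# After `ruff --fix --select I`: stdlib grouped first, then third-party packages,
# and names inside each from-import sorted (uppercase before lowercase)
import os
from typing import Optional, Protocol, Union, runtime_checkable

from pydantic import BaseModel
```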

View file

@ -40,6 +40,7 @@ If you need help or guidance, comment on the issue. Issues that are extra friend
3. Ensure the test suite passes.
4. Make sure your code lints using `pre-commit`.
5. If you haven't already, complete the Contributor License Agreement ("CLA").
6. Ensure your pull request follows the [conventional commits format](https://www.conventionalcommits.org/en/v1.0.0/).
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
@ -98,7 +99,8 @@ $ uv sync
```
## Coding Style
* 2 spaces for indentation rather than tabs
* 4 spaces for indentation rather than tabs
* 80 character line length
* ...
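As a side note on the conventional-commits requirement added above: messages follow `<type>(<optional scope>): <description>`, so something like `feat(providers): add qdrant to vector_io tests` or `fix: handle missing tool_config` would qualify; these particular scopes and descriptions are only illustrative.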

View file

@ -7,13 +7,13 @@
[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb)
Llama Stack defines and standardizes the core building blocks that simplify AI application development. It codified best practices across the Llama ecosystem. More specifically, it provides
Llama Stack standardizes the core building blocks that simplify AI application development. It codifies best practices across the Llama ecosystem. More specifically, it provides
- **Unified API layer** for Inference, RAG, Agents, Tools, Safety, Evals, and Telemetry.
- **Plugin architecture** to support the rich ecosystem of implementations of the different APIs in different environments like local development, on-premises, cloud, and mobile.
- **Prepackaged verified distributions** which offer a one-stop solution for developers to get started quickly and reliably in any environment
- **Multiple developer interfaces** like CLI and SDKs for Python, Typescript, iOS, and Android
- **Standalone applications** as examples for how to build production-grade AI applications with Llama Stack
- **Plugin architecture** to support the rich ecosystem of different API implementations in various environments, including local development, on-premises, cloud, and mobile.
- **Prepackaged verified distributions** which offer a one-stop solution for developers to get started quickly and reliably in any environment.
- **Multiple developer interfaces** like CLI and SDKs for Python, Typescript, iOS, and Android.
- **Standalone applications** as examples for how to build production-grade AI applications with Llama Stack.
<div style="text-align: center;">
<img
@ -25,14 +25,14 @@ Llama Stack defines and standardizes the core building blocks that simplify AI a
</div>
### Llama Stack Benefits
- **Flexible Options**: Developers can choose their preferred infrastructure without changing APIs and enjoy flexible deployment choice.
- **Consistent Experience**: With its unified APIs Llama Stack makes it easier to build, test, and deploy AI applications with consistent application behavior.
- **Flexible Options**: Developers can choose their preferred infrastructure without changing APIs and enjoy flexible deployment choices.
- **Consistent Experience**: With its unified APIs, Llama Stack makes it easier to build, test, and deploy AI applications with consistent application behavior.
- **Robust Ecosystem**: Llama Stack is already integrated with distribution partners (cloud providers, hardware vendors, and AI-focused companies) that offer tailored infrastructure, software, and services for deploying Llama models.
By reducing friction and complexity, Llama Stack empowers developers to focus on what they do best: building transformative generative AI applications.
### API Providers
Here is a list of the various API providers and available distributions to developers started easily,
Here is a list of the various API providers and available distributions that can help developers get started easily with Llama Stack.
| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|
@ -71,15 +71,15 @@ A Llama Stack Distribution (or "distro") is a pre-configured bundle of provider
You have two ways to install this repository:
1. **Install as a package**:
* **Install as a package**:
You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command:
```bash
pip install llama-stack
```
2. **Install from source**:
* **Install from source**:
If you prefer to install from the source code, make sure you have [conda installed](https://docs.conda.io/projects/conda/en/stable).
Then, follow these steps:
Then, run the following commands:
```bash
mkdir -p ~/local
cd ~/local
@ -96,10 +96,11 @@ You have two ways to install this repository:
Please checkout our [Documentation](https://llama-stack.readthedocs.io/en/latest/index.html) page for more details.
* [CLI reference](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/index.html)
* Guide using `llama` CLI to work with Llama models (download, study prompts), and building/starting a Llama Stack distribution.
* [Getting Started](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html)
* Quick guide to start a Llama Stack server.
* CLI references
* [llama (server-side) CLI Reference](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/index.html): Guide for using the `llama` CLI to work with Llama models (download, study prompts), and building/starting a Llama Stack distribution.
* [llama (client-side) CLI Reference](https://llama-stack.readthedocs.io/en/latest/references/llama_stack_client_cli_reference.html): Guide for using the `llama-stack-client` CLI, which allows you to query information about the distribution.
* Getting Started
* [Quick guide to start a Llama Stack server](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html).
* [Jupyter notebook](./docs/getting_started.ipynb) to walk through how to use simple text and vision inference llama_stack_client APIs
* The complete Llama Stack lesson [Colab notebook](https://colab.research.google.com/drive/1dtVmxotBsI4cGZQNsJRYPrLiDeT0Wnwt) of the new [Llama 3.2 course on Deeplearning.ai](https://learn.deeplearning.ai/courses/introducing-multimodal-llama-3-2/lesson/8/llama-stack).
* A [Zero-to-Hero Guide](https://github.com/meta-llama/llama-stack/tree/main/docs/zero_to_hero_guide) that guides you through all the key components of llama stack with code samples.
@ -115,6 +116,6 @@ Please checkout our [Documentation](https://llama-stack.readthedocs.io/en/latest
| Typescript | [llama-stack-client-typescript](https://github.com/meta-llama/llama-stack-client-typescript) | [![NPM version](https://img.shields.io/npm/v/llama-stack-client.svg)](https://npmjs.org/package/llama-stack-client)
| Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) | [![Maven version](https://img.shields.io/maven-central/v/com.llama.llamastack/llama-stack-client-kotlin)](https://central.sonatype.com/artifact/com.llama.llamastack/llama-stack-client-kotlin)
Check out our client SDKs for connecting to Llama Stack server in your preferred language, you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [typescript](https://github.com/meta-llama/llama-stack-client-typescript), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) programming languages to quickly build your applications.
Check out our client SDKs for connecting to a Llama Stack server in your preferred language, you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [typescript](https://github.com/meta-llama/llama-stack-client-typescript), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) programming languages to quickly build your applications.
You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repo.

View file

@ -3106,6 +3106,12 @@
"ChatCompletionResponse": {
"type": "object",
"properties": {
"metrics": {
"type": "array",
"items": {
"$ref": "#/components/schemas/MetricEvent"
}
},
"completion_message": {
"$ref": "#/components/schemas/CompletionMessage",
"description": "The complete response message"
@ -3124,6 +3130,74 @@
],
"description": "Response from a chat completion request."
},
"MetricEvent": {
"type": "object",
"properties": {
"trace_id": {
"type": "string"
},
"span_id": {
"type": "string"
},
"timestamp": {
"type": "string",
"format": "date-time"
},
"attributes": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "string"
},
{
"type": "integer"
},
{
"type": "number"
},
{
"type": "boolean"
},
{
"type": "null"
}
]
}
},
"type": {
"type": "string",
"const": "metric",
"default": "metric"
},
"metric": {
"type": "string"
},
"value": {
"oneOf": [
{
"type": "integer"
},
{
"type": "number"
}
]
},
"unit": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"trace_id",
"span_id",
"timestamp",
"type",
"metric",
"value",
"unit"
]
},
"TokenLogProbs": {
"type": "object",
"properties": {
@ -3388,6 +3462,12 @@
"ChatCompletionResponseStreamChunk": {
"type": "object",
"properties": {
"metrics": {
"type": "array",
"items": {
"$ref": "#/components/schemas/MetricEvent"
}
},
"event": {
"$ref": "#/components/schemas/ChatCompletionResponseEvent",
"description": "The event containing the new content"
@ -3600,8 +3680,7 @@
"auto",
"required"
],
"description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model.",
"default": "auto"
"description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
},
"tool_prompt_format": {
"type": "string",
@ -6374,77 +6453,6 @@
"critical"
]
},
"MetricEvent": {
"type": "object",
"properties": {
"trace_id": {
"type": "string"
},
"span_id": {
"type": "string"
},
"timestamp": {
"type": "string",
"format": "date-time"
},
"attributes": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
},
"type": {
"type": "string",
"const": "metric",
"default": "metric"
},
"metric": {
"type": "string"
},
"value": {
"oneOf": [
{
"type": "integer"
},
{
"type": "number"
}
]
},
"unit": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"trace_id",
"span_id",
"timestamp",
"type",
"metric",
"value",
"unit"
]
},
"SpanEndPayload": {
"type": "object",
"properties": {
@ -6502,22 +6510,19 @@
"additionalProperties": {
"oneOf": [
{
"type": "null"
"type": "string"
},
{
"type": "boolean"
"type": "integer"
},
{
"type": "number"
},
{
"type": "string"
"type": "boolean"
},
{
"type": "array"
},
{
"type": "object"
"type": "null"
}
]
}
@ -6575,22 +6580,19 @@
"additionalProperties": {
"oneOf": [
{
"type": "null"
"type": "string"
},
{
"type": "boolean"
"type": "integer"
},
{
"type": "number"
},
{
"type": "string"
"type": "boolean"
},
{
"type": "array"
},
{
"type": "object"
"type": "null"
}
]
}
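To make the new `metrics` field concrete, here is a hypothetical payload that would validate against the `MetricEvent` schema added above; every value is illustrative rather than taken from the commit.
```python
# Illustrative only: a MetricEvent as it would appear in the new "metrics" array.
metric_event = {
    "trace_id": "4bf92f3577b34da6a3ce929d0e0e4736",
    "span_id": "00f067aa0ba902b7",
    "timestamp": "2025-02-13T19:08:46Z",
    "type": "metric",
    "metric": "completion_tokens",
    "value": 128,
    "unit": "tokens",
}

# A ChatCompletionResponse (or stream chunk) can now carry it alongside its usual
# fields, e.g. {"metrics": [metric_event], "completion_message": {...}}.
```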

View file

@ -1925,6 +1925,10 @@ components:
ChatCompletionResponse:
type: object
properties:
metrics:
type: array
items:
$ref: '#/components/schemas/MetricEvent'
completion_message:
$ref: '#/components/schemas/CompletionMessage'
description: The complete response message
@ -1938,6 +1942,46 @@ components:
required:
- completion_message
description: Response from a chat completion request.
MetricEvent:
type: object
properties:
trace_id:
type: string
span_id:
type: string
timestamp:
type: string
format: date-time
attributes:
type: object
additionalProperties:
oneOf:
- type: string
- type: integer
- type: number
- type: boolean
- type: 'null'
type:
type: string
const: metric
default: metric
metric:
type: string
value:
oneOf:
- type: integer
- type: number
unit:
type: string
additionalProperties: false
required:
- trace_id
- span_id
- timestamp
- type
- metric
- value
- unit
TokenLogProbs:
type: object
properties:
@ -2173,6 +2217,10 @@ components:
ChatCompletionResponseStreamChunk:
type: object
properties:
metrics:
type: array
items:
$ref: '#/components/schemas/MetricEvent'
event:
$ref: '#/components/schemas/ChatCompletionResponseEvent'
description: The event containing the new content
@ -2338,7 +2386,6 @@ components:
Whether tool use is required or automatic. This is a hint to the model
which may not be followed. It depends on the Instruction Following capabilities
of the model.
default: auto
tool_prompt_format:
type: string
enum:
@ -4070,47 +4117,6 @@ components:
- warn
- error
- critical
MetricEvent:
type: object
properties:
trace_id:
type: string
span_id:
type: string
timestamp:
type: string
format: date-time
attributes:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type:
type: string
const: metric
default: metric
metric:
type: string
value:
oneOf:
- type: integer
- type: number
unit:
type: string
additionalProperties: false
required:
- trace_id
- span_id
- timestamp
- type
- metric
- value
- unit
SpanEndPayload:
type: object
properties:
@ -4153,12 +4159,11 @@ components:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
- type: integer
- type: number
- type: boolean
- type: 'null'
type:
type: string
const: structured_log
@ -4195,12 +4200,11 @@ components:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
- type: integer
- type: number
- type: boolean
- type: 'null'
type:
type: string
const: unstructured_log

View file

@ -644,7 +644,9 @@ class Generator:
else:
callbacks = None
description = "\n".join(filter(None, [doc_string.short_description, doc_string.long_description]))
description = "\n".join(
filter(None, [doc_string.short_description, doc_string.long_description])
)
return Operation(
tags=[op.defining_class.__name__],
summary=None,
@ -681,6 +683,7 @@ class Generator:
raise NotImplementedError(f"unknown HTTP method: {op.http_method}")
route = op.get_route()
route = route.replace(":path", "")
print(f"route: {route}")
if route in paths:
paths[route].update(pathItem)

View file

@ -130,6 +130,8 @@ class _FormatParameterExtractor:
def _get_route_parameters(route: str) -> List[str]:
extractor = _FormatParameterExtractor()
# Replace all occurrences of ":path" with empty string
route = route.replace(":path", "")
route.format_map(extractor)
return extractor.keys
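A minimal sketch of what this strip accomplishes: the `:path` suffix is Starlette route syntax, not part of the parameter name, and `str.format_map` would otherwise treat everything after the colon as a format spec; removing it first lets the extractor record just the parameter names.
```python
# Sketch: strip the Starlette ":path" converter before extracting parameters.
route = "/models/{model_id:path}"
route = route.replace(":path", "")
assert route == "/models/{model_id}"
# format_map on the cleaned route now only has to look up "model_id".
```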

View file

@ -180,12 +180,45 @@ After this step is successful, you should be able to find the built container im
### Running your Stack server
Now, let's start the Llama Stack Distribution Server. You will need the YAML configuration file which was written out at the end by the `llama stack build` step.
```
llama stack run -h
usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-ipv6] [--env KEY=VALUE] [--tls-keyfile TLS_KEYFILE]
[--tls-certfile TLS_CERTFILE] [--image-type {conda,container,venv}]
config
start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.
positional arguments:
config Path to config file to use for the run
options:
-h, --help show this help message and exit
--port PORT Port to run the server on. Defaults to 8321
--image-name IMAGE_NAME
Name of the image to run. Defaults to the current conda environment
--disable-ipv6 Disable IPv6 support
--env KEY=VALUE Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.
--tls-keyfile TLS_KEYFILE
Path to TLS key file for HTTPS
--tls-certfile TLS_CERTFILE
Path to TLS certificate file for HTTPS
--image-type {conda,container,venv}
Image Type used during the build. This can be either conda or container or venv.
```
```
# Start using template name
llama stack run tgi
# Start using config file
llama stack run ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml
# Start using a venv
llama stack run --image-type venv ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml
# Start using a conda environment
llama stack run --image-type conda ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml
```

View file

@ -15,25 +15,25 @@ from typing import (
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
runtime_checkable,
)
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, URL
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
from llama_stack.apis.inference import (
CompletionMessage,
ResponseFormat,
SamplingParams,
ToolCall,
ToolChoice,
ToolConfig,
ToolPromptFormat,
ToolResponse,
ToolResponseMessage,
UserMessage,
ToolConfig,
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
@ -154,7 +154,7 @@ class AgentConfigCommon(BaseModel):
output_shields: Optional[List[str]] = Field(default_factory=list)
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto, deprecated="use tool_config instead")
tool_choice: Optional[ToolChoice] = Field(default=None, deprecated="use tool_config instead")
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead")
tool_config: Optional[ToolConfig] = Field(default=None)
@ -166,11 +166,13 @@ class AgentConfigCommon(BaseModel):
raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
if self.tool_config is None:
self.tool_config = ToolConfig(
tool_choice=self.tool_choice,
tool_prompt_format=self.tool_prompt_format,
)
else:
params = {}
if self.tool_choice:
params["tool_choice"] = self.tool_choice
if self.tool_prompt_format:
params["tool_prompt_format"] = self.tool_prompt_format
self.tool_config = ToolConfig(**params)
@json_schema_type
@ -333,7 +335,10 @@ class Agents(Protocol):
tool_config: Optional[ToolConfig] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET")
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
method="GET",
)
async def get_agents_turn(
self,
agent_id: str,

View file

@ -13,7 +13,6 @@ from termcolor import cprint
from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
from llama_stack.apis.common.content_types import ToolCallParseStatus
from llama_stack.apis.inference import ToolResponseMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)

View file

@ -8,7 +8,6 @@ from enum import Enum
from typing import Annotated, List, Literal, Optional, Union
from llama_models.llama3.api.datatypes import ToolCall
from llama_models.schema_utils import json_schema_type, register_schema
from pydantic import BaseModel, Field, model_validator

View file

@ -8,7 +8,6 @@ from enum import Enum
from typing import Any, Dict, Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
from llama_stack.apis.common.content_types import URL

View file

@ -58,7 +58,7 @@ class Datasets(Protocol):
metadata: Optional[Dict[str, Any]] = None,
) -> None: ...
@webmethod(route="/datasets/{dataset_id}", method="GET")
@webmethod(route="/datasets/{dataset_id:path}", method="GET")
async def get_dataset(
self,
dataset_id: str,
@ -67,7 +67,7 @@ class Datasets(Protocol):
@webmethod(route="/datasets", method="GET")
async def list_datasets(self) -> ListDatasetsResponse: ...
@webmethod(route="/datasets/{dataset_id}", method="DELETE")
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE")
async def unregister_dataset(
self,
dataset_id: str,

View file

@ -13,8 +13,8 @@ from typing import (
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
runtime_checkable,
)
from llama_models.llama3.api.datatypes import (
@ -31,6 +31,7 @@ from typing_extensions import Annotated
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@ -357,7 +358,7 @@ class ChatCompletionRequest(BaseModel):
@json_schema_type
class ChatCompletionResponseStreamChunk(BaseModel):
class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
"""A chunk of a streamed chat completion response.
:param event: The event containing the new content
@ -367,7 +368,7 @@ class ChatCompletionResponseStreamChunk(BaseModel):
@json_schema_type
class ChatCompletionResponse(BaseModel):
class ChatCompletionResponse(MetricResponseMixin, BaseModel):
"""Response from a chat completion request.
:param completion_message: The complete response message

View file

@ -62,7 +62,7 @@ class Models(Protocol):
@webmethod(route="/models", method="GET")
async def list_models(self) -> ListModelsResponse: ...
@webmethod(route="/models/{model_id}", method="GET")
@webmethod(route="/models/{model_id:path}", method="GET")
async def get_model(
self,
model_id: str,
@ -78,7 +78,7 @@ class Models(Protocol):
model_type: Optional[ModelType] = None,
) -> Model: ...
@webmethod(route="/models/{model_id}", method="DELETE")
@webmethod(route="/models/{model_id:path}", method="DELETE")
async def unregister_model(
self,
model_id: str,

View file

@ -12,8 +12,8 @@ from typing import (
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
runtime_checkable,
)
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
@ -134,7 +134,7 @@ class ScoringFunctions(Protocol):
@webmethod(route="/scoring-functions", method="GET")
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
@webmethod(route="/scoring-functions/{scoring_fn_id}", method="GET")
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ...
@webmethod(route="/scoring-functions", method="POST")

View file

@ -48,7 +48,7 @@ class Shields(Protocol):
@webmethod(route="/shields", method="GET")
async def list_shields(self) -> ListShieldsResponse: ...
@webmethod(route="/shields/{identifier}", method="GET")
@webmethod(route="/shields/{identifier:path}", method="GET")
async def get_shield(self, identifier: str) -> Optional[Shield]: ...
@webmethod(route="/shields", method="POST")

View file

@ -5,11 +5,9 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_stack.apis.inference import Message

View file

@ -13,10 +13,11 @@ from typing import (
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
runtime_checkable,
)
from llama_models.llama3.api.datatypes import Primitive
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
@ -76,7 +77,7 @@ class EventCommon(BaseModel):
trace_id: str
span_id: str
timestamp: datetime
attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
attributes: Optional[Dict[str, Primitive]] = Field(default_factory=dict)
@json_schema_type
@ -94,6 +95,30 @@ class MetricEvent(EventCommon):
unit: str
# This is a short term solution to allow inference API to return metrics
# The ideal way to do this is to have a way for all response types to include metrics
# and all metric events logged to the telemetry API to be included with the response
# To do this, we will need to augment all response types with a metrics field.
# We have hit a blocker from stainless SDK that prevents us from doing this.
# The blocker is that if we were to augment the response types that have a data field
# in them like so
# class ListModelsResponse(BaseModel):
# metrics: Optional[List[MetricEvent]] = None
# data: List[Models]
# ...
# The client SDK will need to access the data by using a .data field, which is not
# ergonomic. Stainless SDK does support unwrapping the response type, but it
# requires that the response type only have a single field.
# We will need a way in the client SDK to signal that the metrics are needed
# and if they are needed, the client SDK has to return the full response type
# without unwrapping it.
class MetricResponseMixin(BaseModel):
metrics: Optional[List[MetricEvent]] = None
@json_schema_type
class StructuredLogType(Enum):
SPAN_START = "span_start"
@ -199,13 +224,13 @@ class Telemetry(Protocol):
order_by: Optional[List[str]] = None,
) -> QueryTracesResponse: ...
@webmethod(route="/telemetry/traces/{trace_id}", method="GET")
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
async def get_trace(self, trace_id: str) -> Trace: ...
@webmethod(route="/telemetry/traces/{trace_id}/spans/{span_id}", method="GET")
@webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
async def get_span(self, trace_id: str, span_id: str) -> Span: ...
@webmethod(route="/telemetry/spans/{span_id}/tree", method="GET")
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="GET")
async def get_span_tree(
self,
span_id: str,
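The `attributes: Optional[Dict[str, Primitive]]` narrowing above is what drives the revised `oneOf` lists (string/integer/number/boolean/null) in the OpenAPI spec earlier in this commit: scalar attribute values such as `{"model_id": "meta-llama/Llama-3.1-8B-Instruct", "stream": False}` remain valid, while nested arrays and objects no longer are (the keys in that example are illustrative).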

View file

@ -4,5 +4,5 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .tools import * # noqa: F401 F403
from .rag_tool import * # noqa: F401 F403
from .tools import * # noqa: F401 F403

View file

@ -11,7 +11,7 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmeth
from pydantic import BaseModel, Field
from typing_extensions import Annotated, Protocol, runtime_checkable
from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol

View file

@ -11,7 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable
from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@ -101,7 +101,7 @@ class ToolGroups(Protocol):
"""Register a tool group"""
...
@webmethod(route="/toolgroups/{toolgroup_id}", method="GET")
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET")
async def get_tool_group(
self,
toolgroup_id: str,
@ -117,13 +117,13 @@ class ToolGroups(Protocol):
"""List tools with optional tool group"""
...
@webmethod(route="/tools/{tool_name}", method="GET")
@webmethod(route="/tools/{tool_name:path}", method="GET")
async def get_tool(
self,
tool_name: str,
) -> Tool: ...
@webmethod(route="/toolgroups/{toolgroup_id}", method="DELETE")
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE")
async def unregister_toolgroup(
self,
toolgroup_id: str,

View file

@ -46,7 +46,7 @@ class VectorDBs(Protocol):
@webmethod(route="/vector-dbs", method="GET")
async def list_vector_dbs(self) -> ListVectorDBsResponse: ...
@webmethod(route="/vector-dbs/{vector_db_id}", method="GET")
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET")
async def get_vector_db(
self,
vector_db_id: str,
@ -62,5 +62,5 @@ class VectorDBs(Protocol):
provider_vector_db_id: Optional[str] = None,
) -> VectorDB: ...
@webmethod(route="/vector-dbs/{vector_db_id}", method="DELETE")
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
async def unregister_vector_db(self, vector_db_id: str) -> None: ...
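Many of the GET/DELETE routes in this commit switch to Starlette's `{param:path}` converter so that identifiers containing slashes (for example a model ID like `meta-llama/Llama-3.1-8B-Instruct`) are captured as a single parameter. A minimal FastAPI sketch of the difference, with hypothetical endpoint names:
```python
from fastapi import FastAPI

app = FastAPI()

# Plain segment: a "/" ends the match, so
# GET /models-plain/meta-llama/Llama-3.1-8B-Instruct would 404 here.
@app.get("/models-plain/{model_id}")
async def get_model_plain(model_id: str):
    return {"model_id": model_id}

# ":path" converter: the parameter greedily captures the rest of the URL,
# slashes included, so the full model ID arrives in one piece.
@app.get("/models/{model_id:path}")
async def get_model(model_id: str):
    return {"model_id": model_id}
```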

View file

@ -16,11 +16,9 @@ from pathlib import Path
from typing import Dict, List, Optional
import httpx
from llama_models.datatypes import Model
from llama_models.sku_list import LlamaDownloadInfo
from pydantic import BaseModel, ConfigDict
from rich.console import Console
from rich.progress import (
BarColumn,

View file

@ -8,7 +8,6 @@ import argparse
import json
from llama_models.sku_list import resolve_model
from termcolor import colored
from llama_stack.cli.subcommand import Subcommand

View file

@ -38,7 +38,7 @@ class ModelList(Subcommand):
headers = [
"Model Descriptor",
"Hugging Face Repo",
"Model ID",
"Context Length",
]

View file

@ -11,7 +11,6 @@ from llama_stack.cli.model.download import ModelDownload
from llama_stack.cli.model.list import ModelList
from llama_stack.cli.model.prompt_format import ModelPromptFormat
from llama_stack.cli.model.verify_download import ModelVerifyDownload
from llama_stack.cli.subcommand import Subcommand
@ -26,6 +25,8 @@ class ModelParser(Subcommand):
description="Work with llama models",
)
self.parser.set_defaults(func=lambda args: self.parser.print_help())
subparsers = self.parser.add_subparsers(title="model_subcommands")
# Add sub-commands

View file

@ -8,7 +8,7 @@ import argparse
import textwrap
from io import StringIO
from llama_models.datatypes import CoreModelId, is_multimodal, model_family, ModelFamily
from llama_models.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
from llama_stack.cli.subcommand import Subcommand

View file

@ -9,7 +9,6 @@ from typing import Any, Dict, Optional
from llama_models.datatypes import CheckpointQuantizationFormat
from llama_models.llama3.api.datatypes import SamplingParams
from llama_models.sku_list import LlamaDownloadInfo
from pydantic import BaseModel, ConfigDict, Field

View file

@ -21,12 +21,11 @@ from prompt_toolkit.validation import Validator
from termcolor import cprint
from llama_stack.cli.table import print_table
from llama_stack.distribution.build import (
SERVER_DEPENDENCIES,
ImageType,
build_image,
get_provider_dependencies,
ImageType,
SERVER_DEPENDENCIES,
)
from llama_stack.distribution.datatypes import (
BuildConfig,

View file

@ -21,15 +21,19 @@ class StackListProviders(Subcommand):
self._add_arguments()
self.parser.set_defaults(func=self._run_providers_list_cmd)
def _add_arguments(self):
@property
def providable_apis(self):
from llama_stack.distribution.distribution import providable_apis
api_values = [api.value for api in providable_apis()]
return [api.value for api in providable_apis()]
def _add_arguments(self):
self.parser.add_argument(
"api",
type=str,
choices=api_values,
help="API to list providers for (one of: {})".format(api_values),
choices=self.providable_apis,
nargs="?",
help="API to list providers for. List all if not specified.",
)
def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
@ -37,20 +41,29 @@ class StackListProviders(Subcommand):
from llama_stack.distribution.distribution import Api, get_provider_registry
all_providers = get_provider_registry()
providers_for_api = all_providers[Api(args.api)]
if args.api:
providers = [(args.api, all_providers[Api(args.api)])]
else:
providers = [(k.value, prov) for k, prov in all_providers.items()]
providers = [p for api, p in providers if api in self.providable_apis]
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [
"API Type",
"Provider Type",
"PIP Package Dependencies",
]
rows = []
for spec in providers_for_api.values():
if spec.provider_type == "sample":
specs = [spec for p in providers for spec in p.values()]
for spec in specs:
if spec.is_sample:
continue
rows.append(
[
spec.api.value,
spec.provider_type,
",".join(spec.pip_packages),
]
@ -59,4 +72,5 @@ class StackListProviders(Subcommand):
rows,
headers,
separate_rows=True,
sort_by=(0, 1),
)
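With the `api` argument now optional, the list-providers subcommand presumably supports both forms: `llama stack list-providers` to print every API's providers (ordered by API type, then provider type, via the new `sort_by=(0, 1)`), and `llama stack list-providers inference` to keep the old single-API behavior; the exact invocation spelling is inferred from the class name here rather than stated in the diff.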

View file

@ -65,6 +65,13 @@ class StackRun(Subcommand):
type=str,
help="Path to TLS certificate file for HTTPS",
)
self.parser.add_argument(
"--image-type",
type=str,
help="Image Type used during the build. This can be either conda or container or venv.",
choices=["conda", "container", "venv"],
default="conda",
)
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
import importlib.resources
@ -118,11 +125,11 @@ class StackRun(Subcommand):
config_dict = yaml.safe_load(config_file.read_text())
config = parse_and_maybe_upgrade_config(config_dict)
if config.container_image:
if args.image_type == ImageType.container.value or config.container_image:
script = importlib.resources.files("llama_stack") / "distribution/start_container.sh"
image_name = f"distribution-{template_name}" if template_name else config.container_image
run_args = [script, image_name]
else:
elif args.image_type == ImageType.conda.value:
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
image_name = args.image_name or current_conda_env
if not image_name:
@ -167,6 +174,15 @@ class StackRun(Subcommand):
script,
image_name,
]
else:
# else must be venv since that is the only valid option left.
current_venv = os.environ.get("VIRTUAL_ENV")
venv = args.image_name or current_venv
script = importlib.resources.files("llama_stack") / "distribution/start_venv.sh"
run_args = [
script,
venv,
]
run_args.extend([str(config_file), str(args.port)])
if args.disable_ipv6:

View file

@ -31,6 +31,8 @@ class StackParser(Subcommand):
version=f"{version('llama-stack')}",
)
self.parser.set_defaults(func=lambda args: self.parser.print_help())
subparsers = self.parser.add_subparsers(title="stack_subcommands")
# Add sub-commands

View file

@ -6,6 +6,7 @@
import re
import textwrap
from typing import Iterable
from termcolor import cprint
@ -39,11 +40,15 @@ def format_row(row, col_widths):
return "\n".join(lines)
def print_table(rows, headers=None, separate_rows: bool = False):
def print_table(rows, headers=None, separate_rows: bool = False, sort_by: Iterable[int] = tuple()):
def itemlen(item):
return max([len(line) for line in strip_ansi_colors(item).split("\n")])
rows = [[x or "" for x in row] for row in rows]
if sort_by:
rows.sort(key=lambda x: tuple(x[i] for i in sort_by))
if not headers:
col_widths = [max(itemlen(item) for item in col) for col in zip(*rows)]
else:
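A small usage sketch of the new `sort_by` parameter (the rows are made up): it sorts rows by the given column indices before column widths are computed, which is how list-providers gets its ordered output above.
```python
from llama_stack.cli.table import print_table

rows = [
    ["inference", "remote::vllm", "openai"],
    ["agents", "inline::meta-reference", "matplotlib,pillow,pandas"],
    ["inference", "inline::meta-reference", "torch,torchvision"],
]
# Sort by API type (column 0), then provider type (column 1).
print_table(
    rows,
    headers=["API Type", "Provider Type", "PIP Package Dependencies"],
    sort_by=(0, 1),
)
```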

View file

@ -8,6 +8,7 @@ from datetime import datetime
import pytest
import yaml
from llama_stack.distribution.configure import (
LLAMA_STACK_RUN_CONFIG_VERSION,
parse_and_maybe_upgrade_config,

View file

@ -8,7 +8,6 @@ import importlib.resources
import logging
import sys
from enum import Enum
from pathlib import Path
from typing import Dict, List
@ -16,11 +15,8 @@ from pydantic import BaseModel
from termcolor import cprint
from llama_stack.distribution.datatypes import BuildConfig, Provider
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.utils.exec import run_command, run_with_pty
from llama_stack.providers.datatypes import Api

View file

@ -5,18 +5,16 @@
# the root directory of this source tree.
import inspect
import json
from collections.abc import AsyncIterator
from enum import Enum
from typing import Any, get_args, get_origin, Type, Union
from typing import Any, Type, Union, get_args, get_origin
import httpx
from pydantic import BaseModel, parse_obj_as
from termcolor import cprint
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.providers.datatypes import RemoteProviderConfig
_CLIENT_CLASSES = {}

View file

@ -5,12 +5,11 @@
# the root directory of this source tree.
import logging
import textwrap
from typing import Any, Dict
from llama_stack.distribution.datatypes import (
DistributionSpec,
LLAMA_STACK_RUN_CONFIG_VERSION,
DistributionSpec,
Provider,
StackRunConfig,
)
@ -20,7 +19,6 @@ from llama_stack.distribution.distribution import (
)
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
from llama_stack.providers.datatypes import Api, ProviderSpec
logger = logging.getLogger(__name__)

View file

@ -82,3 +82,6 @@ class DistributionInspectImpl(Inspect):
async def version(self) -> VersionInfo:
return VersionInfo(version=version("llama-stack"))
async def shutdown(self) -> None:
pass

View file

@ -13,10 +13,21 @@ import re
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from pathlib import Path
from typing import Any, get_args, get_origin, Optional, TypeVar
from typing import Any, Optional, TypeVar, get_args, get_origin
import httpx
import yaml
from llama_stack_client import (
NOT_GIVEN,
APIResponse,
AsyncAPIResponse,
AsyncLlamaStackClient,
AsyncStream,
LlamaStackClient,
)
from pydantic import BaseModel, TypeAdapter
from rich.console import Console
from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
@ -35,17 +46,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
setup_logger,
start_trace,
)
from llama_stack_client import (
APIResponse,
AsyncAPIResponse,
AsyncLlamaStackClient,
AsyncStream,
LlamaStackClient,
NOT_GIVEN,
)
from pydantic import BaseModel, TypeAdapter
from rich.console import Console
from termcolor import cprint
T = TypeVar("T")

View file

@ -7,7 +7,6 @@
from typing import Any, Dict
from llama_stack.distribution.datatypes import RoutedProtocol
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable

View file

@ -6,7 +6,7 @@
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.eval import (
AppEvalTaskConfig,

View file

@ -537,3 +537,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
for tool in tools:
await self.unregister_object(tool)
await self.unregister_object(tool_group)
async def shutdown(self) -> None:
pass

View file

@ -10,11 +10,8 @@ from typing import Dict, List
from pydantic import BaseModel
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.distribution.resolver import api_protocol_map
from llama_stack.providers.datatypes import Api

View file

@ -9,6 +9,7 @@ import asyncio
import functools
import inspect
import json
import logging
import os
import signal
import sys
@ -20,7 +21,8 @@ from pathlib import Path
from typing import Any, List, Union
import yaml
from fastapi import Body, FastAPI, HTTPException, Path as FastapiPath, Request
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
@ -52,6 +54,9 @@ from .endpoints import get_all_api_endpoints
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
logger = logging.getLogger(__name__)
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
log = file if hasattr(file, "write") else sys.stderr
@ -112,21 +117,69 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
)
def handle_sigint(app, *args, **kwargs):
print("SIGINT or CTRL-C detected. Exiting gracefully...")
def handle_signal(app, signum, _) -> None:
"""
Handle incoming signals and initiate a graceful shutdown of the application.
async def run_shutdown():
for impl in app.__llama_stack_impls__.values():
print(f"Shutting down {impl}")
await impl.shutdown()
This function is intended to be used as a signal handler for various signals
(e.g., SIGINT, SIGTERM). Upon receiving a signal, it will print a message
indicating the received signal and initiate a shutdown process.
asyncio.run(run_shutdown())
Args:
app: The application instance containing implementations to be shut down.
signum (int): The signal number received.
frame: The current stack frame (not used in this function).
loop = asyncio.get_event_loop()
for task in asyncio.all_tasks(loop):
task.cancel()
The shutdown process involves:
- Shutting down all implementations registered in the application.
- Gathering all running asyncio tasks.
- Cancelling all gathered tasks.
- Waiting for all tasks to finish.
- Stopping the event loop.
loop.stop()
Note:
This function schedules the shutdown process as an asyncio task and does
not block the current execution.
"""
signame = signal.Signals(signum).name
print(f"Received signal {signame} ({signum}). Exiting gracefully...")
async def shutdown():
try:
# Gracefully shut down implementations
for impl in app.__llama_stack_impls__.values():
impl_name = impl.__class__.__name__
logger.info("Shutting down %s", impl_name)
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning("No shutdown method for %s", impl_name)
except asyncio.TimeoutError:
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
except Exception as e:
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
# Gather all running tasks
loop = asyncio.get_running_loop()
tasks = [task for task in asyncio.all_tasks(loop) if task is not asyncio.current_task()]
# Cancel all tasks
for task in tasks:
task.cancel()
# Wait for all tasks to finish
try:
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
except asyncio.TimeoutError:
logger.exception("Timeout while waiting for tasks to finish")
except asyncio.CancelledError:
pass
finally:
loop.stop()
loop = asyncio.get_running_loop()
loop.create_task(shutdown())
@asynccontextmanager
@ -386,7 +439,8 @@ def main():
print("")
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, functools.partial(handle_sigint, app))
signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
signal.signal(signal.SIGTERM, functools.partial(handle_signal, app))
app.__llama_stack_impls__ = impls
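One practical effect of registering the handler for both signals: sending `SIGTERM` to the server process (for example `kill -TERM <pid>`, where `<pid>` is the running server's process ID) now goes through the same graceful shutdown path as Ctrl-C, shutting down each implementation with a timeout before cancelling outstanding tasks and stopping the event loop.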

View file

@ -0,0 +1,71 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
RED='\033[0;31m'
NC='\033[0m' # No Color
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then
echo "Usage: $0 <venv_path> <yaml_config> <port> <script_args...>"
exit 1
fi
venv_path="$1"
shift
yaml_config="$1"
shift
port="$1"
shift
# Initialize env_vars as an empty array
env_vars=""
other_args=""
# Process environment variables from --env arguments
while [[ $# -gt 0 ]]; do
case "$1" in
--env)
if [[ -n "$2" ]]; then
env_vars="$env_vars --env $2"
shift 2
else
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
exit 1
fi
;;
*)
other_args="$other_args $1"
shift
;;
esac
done
# Activate virtual environment
if [ ! -d "$venv_path" ]; then
echo -e "${RED}Error: Virtual environment not found at $venv_path${NC}" >&2
exit 1
fi
source "$venv_path/bin/activate"
set -x
python -m llama_stack.distribution.server.server \
--yaml-config "$yaml_config" \
--port "$port" \
$env_vars \
$other_args
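Given how `run.py` above assembles `run_args`, this script would be invoked roughly as `start_venv.sh <venv_path> <run.yaml> <port> [--env KEY=VALUE ...]`, e.g. `start_venv.sh ~/.venvs/llamastack ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml 8321`; the venv path here is illustrative.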

View file

@ -8,9 +8,9 @@ import os
import pytest
import pytest_asyncio
from llama_stack.apis.inference import Model
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.distribution.store.registry import (
CachedDiskDistributionRegistry,
DiskDistributionRegistry,

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import os
from typing import Optional
from llama_stack_client import LlamaStackClient

View file

@ -10,7 +10,6 @@ from page.distribution.models import models
from page.distribution.scoring_functions import scoring_functions
from page.distribution.shields import shields
from page.distribution.vector_dbs import vector_dbs
from streamlit_option_menu import option_menu

View file

@ -8,7 +8,6 @@ import json
import pandas as pd
import streamlit as st
from modules.api import llama_stack_api
from modules.utils import process_dataset

View file

@ -7,9 +7,7 @@
import json
import pandas as pd
import streamlit as st
from modules.api import llama_stack_api

View file

@ -9,7 +9,6 @@ from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.memory_insert_params import Document
from modules.api import llama_stack_api
from modules.utils import data_url_from_file

View file

@ -7,7 +7,6 @@
import os
from pathlib import Path
LLAMA_STACK_CONFIG_DIR = Path(os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/")))
DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"

View file

@ -8,13 +8,11 @@ import inspect
import json
import logging
from enum import Enum
from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
from typing import Any, List, Literal, Optional, Type, Union, get_args, get_origin
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefinedType
from typing_extensions import Annotated
log = logging.getLogger(__name__)

View file

@ -11,7 +11,6 @@ from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from llama_stack.apis.datasets import Dataset
from llama_stack.apis.datatypes import Api
from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.models import Model
@ -86,6 +85,10 @@ class ProviderSpec(BaseModel):
# used internally by the resolver; this is a hack for now
deps__: List[str] = Field(default_factory=list)
@property
def is_sample(self) -> bool:
return self.provider_type in ("sample", "remote::sample")
class RoutingTable(Protocol):
def get_provider_impl(self, routing_key: str) -> Any: ...

View file

@ -42,10 +42,10 @@ from llama_stack.apis.agents import (
Turn,
)
from llama_stack.apis.common.content_types import (
URL,
TextContentItem,
ToolCallDelta,
ToolCallParseStatus,
URL,
)
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
@ -513,6 +513,9 @@ class ChatAgent(ShieldRunnerMixin):
if delta.type == "tool_call":
if delta.parse_status == ToolCallParseStatus.succeeded:
tool_calls.append(delta.tool_call)
elif delta.parse_status == ToolCallParseStatus.failed:
# If we cannot parse the tools, set the content to the unparsed raw text
content = delta.tool_call
if stream:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(

View file

@ -81,12 +81,6 @@ class MetaReferenceAgentsImpl(Agents):
) -> AgentCreateResponse:
agent_id = str(uuid.uuid4())
if agent_config.tool_config is None:
agent_config.tool_config = ToolConfig(
tool_choice=agent_config.tool_choice,
tool_prompt_format=agent_config.tool_prompt_format,
)
await self.persistence_store.set(
key=f"agent:{agent_id}",
value=agent_config.model_dump_json(),
@ -218,3 +212,6 @@ class MetaReferenceAgentsImpl(Agents):
async def delete_agent(self, agent_id: str) -> None:
await self.persistence_store.delete(f"agent:{agent_id}")
async def shutdown(self) -> None:
pass

View file

@ -6,11 +6,9 @@
import asyncio
import logging
from typing import List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
log = logging.getLogger(__name__)

View file

@ -41,7 +41,6 @@ from llama_stack.apis.tools import (
ToolInvocationResult,
)
from llama_stack.apis.vector_io import QueryChunksResponse
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
MEMORY_QUERY_TOOL,
)

View file

@ -15,14 +15,12 @@ import pandas
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from llama_stack.providers.utils.kvstore import kvstore_impl
from .config import LocalFSDatasetIOConfig
DATASETS_PREFIX = "localfs_datasets:"

View file

@ -15,7 +15,6 @@ from llama_stack.apis.inference import Inference, UserMessage
from llama_stack.apis.scoring import Scoring
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
MEMORY_QUERY_TOOL,
)
@ -28,7 +27,6 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
from .....apis.common.job_types import Job
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
from .config import MetaReferenceEvalConfig
EVAL_TASKS_PREFIX = "eval_tasks:"

View file

@ -9,7 +9,6 @@ from typing import Any, Dict, Optional
from pydantic import BaseModel, field_validator
from llama_stack.apis.inference import QuantizationConfig
from llama_stack.providers.utils.inference import supported_inference_models

View file

@ -37,7 +37,6 @@ from llama_models.llama3.reference_impl.multimodal.model import (
CrossAttentionTransformer,
)
from llama_models.sku_list import resolve_model
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from pydantic import BaseModel
@ -47,7 +46,6 @@ from llama_stack.apis.inference import (
ResponseFormat,
ResponseFormatType,
)
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,

View file

@ -46,8 +46,8 @@ from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
build_model_alias,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
augment_content_with_response_format_prompt,

View file

@ -22,16 +22,13 @@ from typing import Callable, Generator, Literal, Optional, Union
import torch
import zmq
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_group,
get_model_parallel_rank,
get_model_parallel_src_rank,
)
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
from typing_extensions import Annotated
from llama_stack.providers.utils.inference.prompt_adapter import (

View file

@ -8,7 +8,6 @@
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import collections
import logging
from typing import Optional, Type
@ -23,7 +22,7 @@ except ImportError:
raise
import torch
from torch import nn, Tensor
from torch import Tensor, nn
class Fp8ScaledWeights:

View file

@ -10,9 +10,9 @@
import unittest
import torch
from fp8_impls import ffn_swiglu_fp8_dynamic, FfnQuantizeMode, quantize_fp8
from hypothesis import given, settings, strategies as st
from fp8_impls import FfnQuantizeMode, ffn_swiglu_fp8_dynamic, quantize_fp8
from hypothesis import given, settings
from hypothesis import strategies as st
from torch import Tensor

View file

@ -12,18 +12,13 @@ import os
from typing import Any, Dict, List, Optional
import torch
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.datatypes import CheckpointQuantizationFormat
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from llama_models.sku_list import resolve_model
from torch import nn, Tensor
from torch import Tensor, nn
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from llama_stack.apis.inference import QuantizationType

View file

@ -16,14 +16,12 @@ from pathlib import Path
from typing import Optional
import fire
import torch
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock

View file

@ -15,9 +15,9 @@ from llama_stack.apis.inference import (
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
ToolConfig,
)
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (

View file

@ -37,9 +37,9 @@ from llama_stack.apis.inference import (
from llama_stack.apis.models import Model
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
@ -201,7 +201,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(response, self.formatter)
return process_chat_completion_response(response, self.formatter, request)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
@ -227,7 +227,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, self.formatter):
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
yield chunk
async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:

View file

@ -15,10 +15,8 @@ from typing import Any, Callable, Dict
import torch
from llama_models.datatypes import Model
from llama_models.sku_list import resolve_model
from pydantic import BaseModel
from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
from torchtune.models.llama3 import llama3_tokenizer
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.models.llama3_1 import lora_llama3_1_8b

View file

@ -13,7 +13,6 @@
from typing import Any, Dict, List, Mapping
import numpy as np
from torch.utils.data import Dataset
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.data._messages import validate_messages

View file

@ -18,9 +18,9 @@ from llama_models.sku_list import resolve_model
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training, utils as torchtune_utils
from torchtune import modules, training
from torchtune import utils as torchtune_utils
from torchtune.data import padded_collate_sft
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
get_adapter_params,
@@ -44,14 +44,11 @@ from llama_stack.apis.post_training import (
OptimizerConfig,
TrainingConfig,
)
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.inline.post_training.common.validator import (
validate_input_dataset_schema,
)
from llama_stack.providers.inline.post_training.torchtune.common import utils
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
TorchtuneCheckpointer,

View file

@@ -21,7 +21,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import CodeScannerConfig
log = logging.getLogger(__name__)
ALLOWED_CODE_SCANNER_MODEL_IDS = [

View file

@@ -5,7 +5,6 @@
# the root directory of this source tree.
import re
from string import Template
from typing import Any, Dict, List, Optional
@@ -25,10 +24,8 @@ from llama_stack.apis.safety import (
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
@@ -36,7 +33,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import LlamaGuardConfig
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
SAFE_RESPONSE = "safe"

View file

@@ -8,7 +8,6 @@ import logging
from typing import Any, Dict, List
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from llama_stack.apis.inference import Message
@@ -19,7 +18,6 @@ from llama_stack.apis.safety import (
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (

View file

@@ -14,13 +14,13 @@ from llama_stack.apis.scoring import (
ScoringResult,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
get_valid_schemas,
validate_dataset_schema,
)
from .config import BasicScoringConfig
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn

View file

@@ -7,7 +7,6 @@
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

View file

@@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import (
ScoringFn,
)
equality = ScoringFn(
identifier="basic::equality",
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",

View file

@@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import (
ScoringFn,
)
subset_of = ScoringFn(
identifier="basic::subset_of",
description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",

View file

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow

View file

@@ -29,9 +29,7 @@ from llama_stack.apis.scoring import (
ScoringResultRow,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
@@ -39,8 +37,8 @@ from llama_stack.providers.utils.common.data_schema_validator import (
validate_dataset_schema,
validate_row_schema,
)
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from .config import BraintrustScoringConfig
from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
from .scoring_fn.fn_defs.answer_relevancy import answer_relevancy_fn_def

View file

@@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import (
ScoringFn,
)
answer_correctness_fn_def = ScoringFn(
identifier="braintrust::answer-correctness",
description=(

View file

@@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import (
ScoringFn,
)
factuality_fn_def = ScoringFn(
identifier="braintrust::factuality",
description=(

View file

@@ -8,7 +8,6 @@ from typing import Any, Dict, List, Optional
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.scoring import (
ScoreBatchResponse,
ScoreResponse,
@@ -26,7 +25,6 @@ from llama_stack.providers.utils.common.data_schema_validator import (
from .config import LlmAsJudgeScoringConfig
from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]

View file

@@ -7,7 +7,6 @@
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams, ScoringFn
llm_as_judge_base = ScoringFn(
identifier="llm-as-judge::base",
description="Llm As Judge Scoring Function",

View file

@@ -4,18 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from typing import Any, Dict, Optional
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa
from .fn_defs.llm_as_judge_base import llm_as_judge_base

View file

@@ -5,6 +5,7 @@
# the root directory of this source tree.
from llama_stack.apis.telemetry import Telemetry
from .config import SampleConfig

View file

@@ -82,7 +82,11 @@ import sys as _sys
# them with linters - they're used in code_execution.py
from contextlib import ( # noqa
contextmanager as _contextmanager,
)
from contextlib import (
redirect_stderr as _redirect_stderr,
)
from contextlib import (
redirect_stdout as _redirect_stdout,
)
from multiprocessing.connection import Connection as _Connection

View file

@@ -9,7 +9,6 @@ from jinja2 import Template
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import UserMessage
from llama_stack.apis.tools.rag_tool import (
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,

View file

@@ -11,9 +11,9 @@ import string
from typing import Any, Dict, List, Optional
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
TextContentItem,
URL,
)
from llama_stack.apis.inference import Inference
from llama_stack.apis.tools import (

View file

@@ -8,10 +8,10 @@ from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import ChromaInlineImplConfig
from .config import ChromaVectorIOConfig
async def get_provider_impl(config: ChromaInlineImplConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from llama_stack.providers.remote.vector_io.chroma.chroma import (
ChromaVectorIOAdapter,
)

View file

@@ -9,7 +9,7 @@ from typing import Any, Dict
from pydantic import BaseModel
class ChromaInlineImplConfig(BaseModel):
class ChromaVectorIOConfig(BaseModel):
db_path: str
@classmethod

View file

@@ -7,14 +7,15 @@
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import FaissImplConfig
from .config import FaissVectorIOConfig
async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]):
from .faiss import FaissVectorIOImpl
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from .faiss import FaissVectorIOAdapter
assert isinstance(config, FaissImplConfig), f"Unexpected config type: {type(config)}"
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
impl = FaissVectorIOImpl(config, deps[Api.inference])
impl = FaissVectorIOAdapter(config, deps[Api.inference])
await impl.initialize()
return impl

View file

@@ -16,7 +16,7 @@ from llama_stack.providers.utils.kvstore.config import (
@json_schema_type
class FaissImplConfig(BaseModel):
class FaissVectorIOConfig(BaseModel):
kvstore: KVStoreConfig
@classmethod

View file

@@ -8,11 +8,9 @@ import base64
import io
import json
import logging
from typing import Any, Dict, List, Optional
import faiss
import numpy as np
from numpy.typing import NDArray
@@ -26,7 +24,7 @@ from llama_stack.providers.utils.memory.vector_store import (
VectorDBWithIndex,
)
from .config import FaissImplConfig
from .config import FaissVectorIOConfig
logger = logging.getLogger(__name__)
@@ -114,8 +112,8 @@ class FaissIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
class FaissVectorIOImpl(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: FaissVectorIOConfig, inference_api: Api.inference) -> None:
self.config = config
self.inference_api = inference_api
self.cache = {}
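Note: any downstream code that imported the old faiss names will need the renamed symbols introduced in this hunk. A hedged sketch, assuming the inline provider lives under `llama_stack.providers.inline.vector_io.faiss` (the diff shows only relative imports):

```python
# Old symbols removed by this commit (for reference):
#   FaissImplConfig, FaissVectorIOImpl
# Renamed symbols introduced here (module path assumed):
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissVectorIOAdapter
```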

View file

@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import SQLiteVectorIOConfig
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from .sqlite_vec import SQLiteVecVectorIOAdapter
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference])
await impl.initialize()
return impl

View file

@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# config.py
from typing import Any, Dict
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
class SQLiteVectorIOConfig(BaseModel):
db_path: str
kvstore: KVStoreConfig
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="sqlite_vec.db",
)
}
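For context, a minimal sketch of how this new inline sqlite_vec provider might be instantiated from the config above; the import paths, the kvstore construction, and the inference dependency are assumptions not spelled out in the diff.

```python
# Sketch only: import paths are inferred from the relative imports shown above,
# and `my_inference_impl` is a placeholder for whatever Inference provider the
# distribution resolves.
from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.vector_io.sqlite_vec import get_provider_impl
from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


async def build_sqlite_vec_provider(my_inference_impl):
    config = SQLiteVectorIOConfig(
        db_path="/tmp/sqlite_vec.db",  # hypothetical vector store file
        kvstore=SqliteKVStoreConfig(db_path="/tmp/sqlite_vec_registry.db"),  # field name assumed
    )
    # Mirrors the get_provider_impl signature above, which looks up Api.inference in deps.
    return await get_provider_impl(config, {Api.inference: my_inference_impl})
```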

Some files were not shown because too many files have changed in this diff.