mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Merge branch 'main' into fix_configure
This commit is contained in:
commit
a7250a1e33
39 changed files with 814 additions and 376 deletions
|
@ -51,3 +51,9 @@ repos:
|
||||||
# hooks:
|
# hooks:
|
||||||
# - id: pydoclint
|
# - id: pydoclint
|
||||||
# args: [--config=pyproject.toml]
|
# args: [--config=pyproject.toml]
|
||||||
|
|
||||||
|
# - repo: https://github.com/tcort/markdown-link-check
|
||||||
|
# rev: v3.11.2
|
||||||
|
# hooks:
|
||||||
|
# - id: markdown-link-check
|
||||||
|
# args: ['--quiet']
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
# Llama Stack
|
# Llama Stack
|
||||||
|
|
||||||
|
[](https://pypi.org/project/llama_stack/)
|
||||||
[](https://pypi.org/project/llama-stack/)
|
[](https://pypi.org/project/llama-stack/)
|
||||||
[](https://discord.gg/TZAAYNVtrU)
|
[](https://discord.gg/llama-stack)
|
||||||
|
|
||||||
This repository contains the Llama Stack API specifications as well as API Providers and Llama Stack Distributions.
|
This repository contains the Llama Stack API specifications as well as API Providers and Llama Stack Distributions.
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ The `llama` CLI tool helps you setup and use the Llama toolchain & agentic syste
|
||||||
### Subcommands
|
### Subcommands
|
||||||
1. `download`: `llama` cli tools supports downloading the model from Meta or Hugging Face.
|
1. `download`: `llama` cli tools supports downloading the model from Meta or Hugging Face.
|
||||||
2. `model`: Lists available models and their properties.
|
2. `model`: Lists available models and their properties.
|
||||||
3. `stack`: Allows you to build and run a Llama Stack server. You can read more about this [here](/docs/cli_reference.md#step-3-building-configuring-and-running-llama-stack-servers).
|
3. `stack`: Allows you to build and run a Llama Stack server. You can read more about this [here](cli_reference.md#step-3-building-and-configuring-llama-stack-distributions).
|
||||||
|
|
||||||
### Sample Usage
|
### Sample Usage
|
||||||
|
|
||||||
|
|
|
@ -46,6 +46,7 @@ from llama_stack.apis.safety import * # noqa: F403
|
||||||
from llama_stack.apis.models import * # noqa: F403
|
from llama_stack.apis.models import * # noqa: F403
|
||||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||||
from llama_stack.apis.shields import * # noqa: F403
|
from llama_stack.apis.shields import * # noqa: F403
|
||||||
|
from llama_stack.apis.inspect import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
class LlamaStack(
|
class LlamaStack(
|
||||||
|
@ -63,6 +64,7 @@ class LlamaStack(
|
||||||
Evaluations,
|
Evaluations,
|
||||||
Models,
|
Models,
|
||||||
Shields,
|
Shields,
|
||||||
|
Inspect,
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
"info": {
|
"info": {
|
||||||
"title": "[DRAFT] Llama Stack Specification",
|
"title": "[DRAFT] Llama Stack Specification",
|
||||||
"version": "0.0.1",
|
"version": "0.0.1",
|
||||||
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-09-23 16:58:41.469308"
|
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-02 15:40:53.008257"
|
||||||
},
|
},
|
||||||
"servers": [
|
"servers": [
|
||||||
{
|
{
|
||||||
|
@ -1542,6 +1542,36 @@
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"/health": {
|
||||||
|
"get": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "OK",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/HealthInfo"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"tags": [
|
||||||
|
"Inspect"
|
||||||
|
],
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "X-LlamaStack-ProviderData",
|
||||||
|
"in": "header",
|
||||||
|
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
|
||||||
|
"required": false,
|
||||||
|
"schema": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
"/memory/insert": {
|
"/memory/insert": {
|
||||||
"post": {
|
"post": {
|
||||||
"responses": {
|
"responses": {
|
||||||
|
@ -1665,6 +1695,75 @@
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"/providers/list": {
|
||||||
|
"get": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "OK",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {
|
||||||
|
"$ref": "#/components/schemas/ProviderInfo"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"tags": [
|
||||||
|
"Inspect"
|
||||||
|
],
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "X-LlamaStack-ProviderData",
|
||||||
|
"in": "header",
|
||||||
|
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
|
||||||
|
"required": false,
|
||||||
|
"schema": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"/routes/list": {
|
||||||
|
"get": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "OK",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"$ref": "#/components/schemas/RouteInfo"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"tags": [
|
||||||
|
"Inspect"
|
||||||
|
],
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "X-LlamaStack-ProviderData",
|
||||||
|
"in": "header",
|
||||||
|
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
|
||||||
|
"required": false,
|
||||||
|
"schema": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
"/shields/list": {
|
"/shields/list": {
|
||||||
"get": {
|
"get": {
|
||||||
"responses": {
|
"responses": {
|
||||||
|
@ -4783,7 +4882,7 @@
|
||||||
"provider_config": {
|
"provider_config": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"provider_id": {
|
"provider_type": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"config": {
|
"config": {
|
||||||
|
@ -4814,7 +4913,7 @@
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"required": [
|
"required": [
|
||||||
"provider_id",
|
"provider_type",
|
||||||
"config"
|
"config"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
@ -4843,7 +4942,7 @@
|
||||||
"provider_config": {
|
"provider_config": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"provider_id": {
|
"provider_type": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"config": {
|
"config": {
|
||||||
|
@ -4874,7 +4973,7 @@
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"required": [
|
"required": [
|
||||||
"provider_id",
|
"provider_type",
|
||||||
"config"
|
"config"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
@ -4894,7 +4993,7 @@
|
||||||
"provider_config": {
|
"provider_config": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"provider_id": {
|
"provider_type": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"config": {
|
"config": {
|
||||||
|
@ -4925,7 +5024,7 @@
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"required": [
|
"required": [
|
||||||
"provider_id",
|
"provider_type",
|
||||||
"config"
|
"config"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
@ -5086,6 +5185,18 @@
|
||||||
"job_uuid"
|
"job_uuid"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
"HealthInfo": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"status": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"status"
|
||||||
|
]
|
||||||
|
},
|
||||||
"InsertDocumentsRequest": {
|
"InsertDocumentsRequest": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
@ -5108,6 +5219,45 @@
|
||||||
"documents"
|
"documents"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
"ProviderInfo": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"provider_type": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"description": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"provider_type",
|
||||||
|
"description"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"RouteInfo": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"route": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"method": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"providers": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"route",
|
||||||
|
"method",
|
||||||
|
"providers"
|
||||||
|
]
|
||||||
|
},
|
||||||
"LogSeverity": {
|
"LogSeverity": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": [
|
"enum": [
|
||||||
|
@ -6220,19 +6370,34 @@
|
||||||
],
|
],
|
||||||
"tags": [
|
"tags": [
|
||||||
{
|
{
|
||||||
"name": "Shields"
|
"name": "Datasets"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Inspect"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Memory"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "BatchInference"
|
"name": "BatchInference"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "RewardScoring"
|
"name": "Agents"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Inference"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Shields"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "SyntheticDataGeneration"
|
"name": "SyntheticDataGeneration"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Agents"
|
"name": "Models"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "RewardScoring"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "MemoryBanks"
|
"name": "MemoryBanks"
|
||||||
|
@ -6241,13 +6406,7 @@
|
||||||
"name": "Safety"
|
"name": "Safety"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Models"
|
"name": "Evaluations"
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Inference"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Memory"
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Telemetry"
|
"name": "Telemetry"
|
||||||
|
@ -6255,12 +6414,6 @@
|
||||||
{
|
{
|
||||||
"name": "PostTraining"
|
"name": "PostTraining"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"name": "Datasets"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Evaluations"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"name": "BuiltinTool",
|
"name": "BuiltinTool",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/BuiltinTool\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/BuiltinTool\" />"
|
||||||
|
@ -6653,10 +6806,22 @@
|
||||||
"name": "PostTrainingJob",
|
"name": "PostTrainingJob",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/PostTrainingJob\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/PostTrainingJob\" />"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "HealthInfo",
|
||||||
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/HealthInfo\" />"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "InsertDocumentsRequest",
|
"name": "InsertDocumentsRequest",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/InsertDocumentsRequest\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/InsertDocumentsRequest\" />"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "ProviderInfo",
|
||||||
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/ProviderInfo\" />"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "RouteInfo",
|
||||||
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/RouteInfo\" />"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "LogSeverity",
|
"name": "LogSeverity",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/LogSeverity\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/LogSeverity\" />"
|
||||||
|
@ -6787,6 +6952,7 @@
|
||||||
"Datasets",
|
"Datasets",
|
||||||
"Evaluations",
|
"Evaluations",
|
||||||
"Inference",
|
"Inference",
|
||||||
|
"Inspect",
|
||||||
"Memory",
|
"Memory",
|
||||||
"MemoryBanks",
|
"MemoryBanks",
|
||||||
"Models",
|
"Models",
|
||||||
|
@ -6857,6 +7023,7 @@
|
||||||
"FunctionCallToolDefinition",
|
"FunctionCallToolDefinition",
|
||||||
"GetAgentsSessionRequest",
|
"GetAgentsSessionRequest",
|
||||||
"GetDocumentsRequest",
|
"GetDocumentsRequest",
|
||||||
|
"HealthInfo",
|
||||||
"ImageMedia",
|
"ImageMedia",
|
||||||
"InferenceStep",
|
"InferenceStep",
|
||||||
"InsertDocumentsRequest",
|
"InsertDocumentsRequest",
|
||||||
|
@ -6880,6 +7047,7 @@
|
||||||
"PostTrainingJobStatus",
|
"PostTrainingJobStatus",
|
||||||
"PostTrainingJobStatusResponse",
|
"PostTrainingJobStatusResponse",
|
||||||
"PreferenceOptimizeRequest",
|
"PreferenceOptimizeRequest",
|
||||||
|
"ProviderInfo",
|
||||||
"QLoraFinetuningConfig",
|
"QLoraFinetuningConfig",
|
||||||
"QueryDocumentsRequest",
|
"QueryDocumentsRequest",
|
||||||
"QueryDocumentsResponse",
|
"QueryDocumentsResponse",
|
||||||
|
@ -6888,6 +7056,7 @@
|
||||||
"RestAPIMethod",
|
"RestAPIMethod",
|
||||||
"RewardScoreRequest",
|
"RewardScoreRequest",
|
||||||
"RewardScoringResponse",
|
"RewardScoringResponse",
|
||||||
|
"RouteInfo",
|
||||||
"RunShieldRequest",
|
"RunShieldRequest",
|
||||||
"RunShieldResponse",
|
"RunShieldResponse",
|
||||||
"SafetyViolation",
|
"SafetyViolation",
|
||||||
|
|
|
@ -908,6 +908,14 @@ components:
|
||||||
required:
|
required:
|
||||||
- document_ids
|
- document_ids
|
||||||
type: object
|
type: object
|
||||||
|
HealthInfo:
|
||||||
|
additionalProperties: false
|
||||||
|
properties:
|
||||||
|
status:
|
||||||
|
type: string
|
||||||
|
required:
|
||||||
|
- status
|
||||||
|
type: object
|
||||||
ImageMedia:
|
ImageMedia:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
@ -1117,10 +1125,10 @@ components:
|
||||||
- type: array
|
- type: array
|
||||||
- type: object
|
- type: object
|
||||||
type: object
|
type: object
|
||||||
provider_id:
|
provider_type:
|
||||||
type: string
|
type: string
|
||||||
required:
|
required:
|
||||||
- provider_id
|
- provider_type
|
||||||
- config
|
- config
|
||||||
type: object
|
type: object
|
||||||
required:
|
required:
|
||||||
|
@ -1362,10 +1370,10 @@ components:
|
||||||
- type: array
|
- type: array
|
||||||
- type: object
|
- type: object
|
||||||
type: object
|
type: object
|
||||||
provider_id:
|
provider_type:
|
||||||
type: string
|
type: string
|
||||||
required:
|
required:
|
||||||
- provider_id
|
- provider_type
|
||||||
- config
|
- config
|
||||||
type: object
|
type: object
|
||||||
required:
|
required:
|
||||||
|
@ -1543,6 +1551,17 @@ components:
|
||||||
- hyperparam_search_config
|
- hyperparam_search_config
|
||||||
- logger_config
|
- logger_config
|
||||||
type: object
|
type: object
|
||||||
|
ProviderInfo:
|
||||||
|
additionalProperties: false
|
||||||
|
properties:
|
||||||
|
description:
|
||||||
|
type: string
|
||||||
|
provider_type:
|
||||||
|
type: string
|
||||||
|
required:
|
||||||
|
- provider_type
|
||||||
|
- description
|
||||||
|
type: object
|
||||||
QLoraFinetuningConfig:
|
QLoraFinetuningConfig:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
@ -1704,6 +1723,22 @@ components:
|
||||||
title: Response from the reward scoring. Batch of (prompt, response, score)
|
title: Response from the reward scoring. Batch of (prompt, response, score)
|
||||||
tuples that pass the threshold.
|
tuples that pass the threshold.
|
||||||
type: object
|
type: object
|
||||||
|
RouteInfo:
|
||||||
|
additionalProperties: false
|
||||||
|
properties:
|
||||||
|
method:
|
||||||
|
type: string
|
||||||
|
providers:
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
type: array
|
||||||
|
route:
|
||||||
|
type: string
|
||||||
|
required:
|
||||||
|
- route
|
||||||
|
- method
|
||||||
|
- providers
|
||||||
|
type: object
|
||||||
RunShieldRequest:
|
RunShieldRequest:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
@ -1916,10 +1951,10 @@ components:
|
||||||
- type: array
|
- type: array
|
||||||
- type: object
|
- type: object
|
||||||
type: object
|
type: object
|
||||||
provider_id:
|
provider_type:
|
||||||
type: string
|
type: string
|
||||||
required:
|
required:
|
||||||
- provider_id
|
- provider_type
|
||||||
- config
|
- config
|
||||||
type: object
|
type: object
|
||||||
shield_type:
|
shield_type:
|
||||||
|
@ -2569,7 +2604,7 @@ info:
|
||||||
description: "This is the specification of the llama stack that provides\n \
|
description: "This is the specification of the llama stack that provides\n \
|
||||||
\ a set of endpoints and their corresponding interfaces that are tailored\
|
\ a set of endpoints and their corresponding interfaces that are tailored\
|
||||||
\ to\n best leverage Llama Models. The specification is still in\
|
\ to\n best leverage Llama Models. The specification is still in\
|
||||||
\ draft and subject to change.\n Generated at 2024-09-23 16:58:41.469308"
|
\ draft and subject to change.\n Generated at 2024-10-02 15:40:53.008257"
|
||||||
title: '[DRAFT] Llama Stack Specification'
|
title: '[DRAFT] Llama Stack Specification'
|
||||||
version: 0.0.1
|
version: 0.0.1
|
||||||
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
|
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
|
||||||
|
@ -3093,6 +3128,25 @@ paths:
|
||||||
description: OK
|
description: OK
|
||||||
tags:
|
tags:
|
||||||
- Evaluations
|
- Evaluations
|
||||||
|
/health:
|
||||||
|
get:
|
||||||
|
parameters:
|
||||||
|
- description: JSON-encoded provider data which will be made available to the
|
||||||
|
adapter servicing the API
|
||||||
|
in: header
|
||||||
|
name: X-LlamaStack-ProviderData
|
||||||
|
required: false
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/HealthInfo'
|
||||||
|
description: OK
|
||||||
|
tags:
|
||||||
|
- Inspect
|
||||||
/inference/chat_completion:
|
/inference/chat_completion:
|
||||||
post:
|
post:
|
||||||
parameters:
|
parameters:
|
||||||
|
@ -3637,6 +3691,27 @@ paths:
|
||||||
description: OK
|
description: OK
|
||||||
tags:
|
tags:
|
||||||
- PostTraining
|
- PostTraining
|
||||||
|
/providers/list:
|
||||||
|
get:
|
||||||
|
parameters:
|
||||||
|
- description: JSON-encoded provider data which will be made available to the
|
||||||
|
adapter servicing the API
|
||||||
|
in: header
|
||||||
|
name: X-LlamaStack-ProviderData
|
||||||
|
required: false
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
additionalProperties:
|
||||||
|
$ref: '#/components/schemas/ProviderInfo'
|
||||||
|
type: object
|
||||||
|
description: OK
|
||||||
|
tags:
|
||||||
|
- Inspect
|
||||||
/reward_scoring/score:
|
/reward_scoring/score:
|
||||||
post:
|
post:
|
||||||
parameters:
|
parameters:
|
||||||
|
@ -3662,6 +3737,29 @@ paths:
|
||||||
description: OK
|
description: OK
|
||||||
tags:
|
tags:
|
||||||
- RewardScoring
|
- RewardScoring
|
||||||
|
/routes/list:
|
||||||
|
get:
|
||||||
|
parameters:
|
||||||
|
- description: JSON-encoded provider data which will be made available to the
|
||||||
|
adapter servicing the API
|
||||||
|
in: header
|
||||||
|
name: X-LlamaStack-ProviderData
|
||||||
|
required: false
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
additionalProperties:
|
||||||
|
items:
|
||||||
|
$ref: '#/components/schemas/RouteInfo'
|
||||||
|
type: array
|
||||||
|
type: object
|
||||||
|
description: OK
|
||||||
|
tags:
|
||||||
|
- Inspect
|
||||||
/safety/run_shield:
|
/safety/run_shield:
|
||||||
post:
|
post:
|
||||||
parameters:
|
parameters:
|
||||||
|
@ -3807,20 +3905,21 @@ security:
|
||||||
servers:
|
servers:
|
||||||
- url: http://any-hosted-llama-stack.com
|
- url: http://any-hosted-llama-stack.com
|
||||||
tags:
|
tags:
|
||||||
- name: Shields
|
- name: Datasets
|
||||||
|
- name: Inspect
|
||||||
|
- name: Memory
|
||||||
- name: BatchInference
|
- name: BatchInference
|
||||||
- name: RewardScoring
|
|
||||||
- name: SyntheticDataGeneration
|
|
||||||
- name: Agents
|
- name: Agents
|
||||||
|
- name: Inference
|
||||||
|
- name: Shields
|
||||||
|
- name: SyntheticDataGeneration
|
||||||
|
- name: Models
|
||||||
|
- name: RewardScoring
|
||||||
- name: MemoryBanks
|
- name: MemoryBanks
|
||||||
- name: Safety
|
- name: Safety
|
||||||
- name: Models
|
- name: Evaluations
|
||||||
- name: Inference
|
|
||||||
- name: Memory
|
|
||||||
- name: Telemetry
|
- name: Telemetry
|
||||||
- name: PostTraining
|
- name: PostTraining
|
||||||
- name: Datasets
|
|
||||||
- name: Evaluations
|
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
|
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
|
||||||
name: BuiltinTool
|
name: BuiltinTool
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
|
||||||
|
@ -4135,9 +4234,15 @@ tags:
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/PostTrainingJob"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/PostTrainingJob"
|
||||||
/>
|
/>
|
||||||
name: PostTrainingJob
|
name: PostTrainingJob
|
||||||
|
- description: <SchemaDefinition schemaRef="#/components/schemas/HealthInfo" />
|
||||||
|
name: HealthInfo
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/InsertDocumentsRequest"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/InsertDocumentsRequest"
|
||||||
/>
|
/>
|
||||||
name: InsertDocumentsRequest
|
name: InsertDocumentsRequest
|
||||||
|
- description: <SchemaDefinition schemaRef="#/components/schemas/ProviderInfo" />
|
||||||
|
name: ProviderInfo
|
||||||
|
- description: <SchemaDefinition schemaRef="#/components/schemas/RouteInfo" />
|
||||||
|
name: RouteInfo
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/LogSeverity" />
|
- description: <SchemaDefinition schemaRef="#/components/schemas/LogSeverity" />
|
||||||
name: LogSeverity
|
name: LogSeverity
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/MetricEvent" />
|
- description: <SchemaDefinition schemaRef="#/components/schemas/MetricEvent" />
|
||||||
|
@ -4236,6 +4341,7 @@ x-tagGroups:
|
||||||
- Datasets
|
- Datasets
|
||||||
- Evaluations
|
- Evaluations
|
||||||
- Inference
|
- Inference
|
||||||
|
- Inspect
|
||||||
- Memory
|
- Memory
|
||||||
- MemoryBanks
|
- MemoryBanks
|
||||||
- Models
|
- Models
|
||||||
|
@ -4303,6 +4409,7 @@ x-tagGroups:
|
||||||
- FunctionCallToolDefinition
|
- FunctionCallToolDefinition
|
||||||
- GetAgentsSessionRequest
|
- GetAgentsSessionRequest
|
||||||
- GetDocumentsRequest
|
- GetDocumentsRequest
|
||||||
|
- HealthInfo
|
||||||
- ImageMedia
|
- ImageMedia
|
||||||
- InferenceStep
|
- InferenceStep
|
||||||
- InsertDocumentsRequest
|
- InsertDocumentsRequest
|
||||||
|
@ -4326,6 +4433,7 @@ x-tagGroups:
|
||||||
- PostTrainingJobStatus
|
- PostTrainingJobStatus
|
||||||
- PostTrainingJobStatusResponse
|
- PostTrainingJobStatusResponse
|
||||||
- PreferenceOptimizeRequest
|
- PreferenceOptimizeRequest
|
||||||
|
- ProviderInfo
|
||||||
- QLoraFinetuningConfig
|
- QLoraFinetuningConfig
|
||||||
- QueryDocumentsRequest
|
- QueryDocumentsRequest
|
||||||
- QueryDocumentsResponse
|
- QueryDocumentsResponse
|
||||||
|
@ -4334,6 +4442,7 @@ x-tagGroups:
|
||||||
- RestAPIMethod
|
- RestAPIMethod
|
||||||
- RewardScoreRequest
|
- RewardScoreRequest
|
||||||
- RewardScoringResponse
|
- RewardScoringResponse
|
||||||
|
- RouteInfo
|
||||||
- RunShieldRequest
|
- RunShieldRequest
|
||||||
- RunShieldResponse
|
- RunShieldResponse
|
||||||
- SafetyViolation
|
- SafetyViolation
|
||||||
|
|
7
llama_stack/apis/inspect/__init__.py
Normal file
7
llama_stack/apis/inspect/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .inspect import * # noqa: F401 F403
|
82
llama_stack/apis/inspect/client.py
Normal file
82
llama_stack/apis/inspect/client.py
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import httpx
|
||||||
|
from termcolor import cprint
|
||||||
|
|
||||||
|
from .inspect import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
class InspectClient(Inspect):
|
||||||
|
def __init__(self, base_url: str):
|
||||||
|
self.base_url = base_url
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def list_providers(self) -> Dict[str, ProviderInfo]:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(
|
||||||
|
f"{self.base_url}/providers/list",
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
print(response.json())
|
||||||
|
return {
|
||||||
|
k: [ProviderInfo(**vi) for vi in v] for k, v in response.json().items()
|
||||||
|
}
|
||||||
|
|
||||||
|
async def list_routes(self) -> Dict[str, List[RouteInfo]]:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(
|
||||||
|
f"{self.base_url}/routes/list",
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return {
|
||||||
|
k: [RouteInfo(**vi) for vi in v] for k, v in response.json().items()
|
||||||
|
}
|
||||||
|
|
||||||
|
async def health(self) -> HealthInfo:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(
|
||||||
|
f"{self.base_url}/health",
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
j = response.json()
|
||||||
|
if j is None:
|
||||||
|
return None
|
||||||
|
return HealthInfo(**j)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_main(host: str, port: int):
|
||||||
|
client = InspectClient(f"http://{host}:{port}")
|
||||||
|
|
||||||
|
response = await client.list_providers()
|
||||||
|
cprint(f"list_providers response={response}", "green")
|
||||||
|
|
||||||
|
response = await client.list_routes()
|
||||||
|
cprint(f"list_routes response={response}", "blue")
|
||||||
|
|
||||||
|
response = await client.health()
|
||||||
|
cprint(f"health response={response}", "yellow")
|
||||||
|
|
||||||
|
|
||||||
|
def main(host: str, port: int):
|
||||||
|
asyncio.run(run_main(host, port))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
fire.Fire(main)
|
40
llama_stack/apis/inspect/inspect.py
Normal file
40
llama_stack/apis/inspect/inspect.py
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Dict, List, Protocol
|
||||||
|
|
||||||
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ProviderInfo(BaseModel):
|
||||||
|
provider_type: str
|
||||||
|
description: str
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class RouteInfo(BaseModel):
|
||||||
|
route: str
|
||||||
|
method: str
|
||||||
|
providers: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class HealthInfo(BaseModel):
|
||||||
|
status: str
|
||||||
|
# TODO: add a provider level status
|
||||||
|
|
||||||
|
|
||||||
|
class Inspect(Protocol):
|
||||||
|
@webmethod(route="/providers/list", method="GET")
|
||||||
|
async def list_providers(self) -> Dict[str, ProviderInfo]: ...
|
||||||
|
|
||||||
|
@webmethod(route="/routes/list", method="GET")
|
||||||
|
async def list_routes(self) -> Dict[str, List[RouteInfo]]: ...
|
||||||
|
|
||||||
|
@webmethod(route="/health", method="GET")
|
||||||
|
async def health(self) -> HealthInfo: ...
|
|
@ -18,7 +18,7 @@ from llama_stack.distribution.datatypes import GenericProviderConfig
|
||||||
class MemoryBankSpec(BaseModel):
|
class MemoryBankSpec(BaseModel):
|
||||||
bank_type: MemoryBankType
|
bank_type: MemoryBankType
|
||||||
provider_config: GenericProviderConfig = Field(
|
provider_config: GenericProviderConfig = Field(
|
||||||
description="Provider config for the model, including provider_id, and corresponding config. ",
|
description="Provider config for the model, including provider_type, and corresponding config. ",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ class ModelServingSpec(BaseModel):
|
||||||
description="All metadatas associated with llama model (defined in llama_models.models.sku_list).",
|
description="All metadatas associated with llama model (defined in llama_models.models.sku_list).",
|
||||||
)
|
)
|
||||||
provider_config: GenericProviderConfig = Field(
|
provider_config: GenericProviderConfig = Field(
|
||||||
description="Provider config for the model, including provider_id, and corresponding config. ",
|
description="Provider config for the model, including provider_type, and corresponding config. ",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ from llama_stack.distribution.datatypes import GenericProviderConfig
|
||||||
class ShieldSpec(BaseModel):
|
class ShieldSpec(BaseModel):
|
||||||
shield_type: str
|
shield_type: str
|
||||||
provider_config: GenericProviderConfig = Field(
|
provider_config: GenericProviderConfig = Field(
|
||||||
description="Provider config for the model, including provider_id, and corresponding config. ",
|
description="Provider config for the model, including provider_type, and corresponding config. ",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,6 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import subprocess
|
|
||||||
import textwrap
|
import textwrap
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
|
||||||
|
@ -110,7 +109,4 @@ def render_markdown_to_pager(markdown_content: str):
|
||||||
console = Console(file=output, force_terminal=True, width=100) # Set a fixed width
|
console = Console(file=output, force_terminal=True, width=100) # Set a fixed width
|
||||||
console.print(md)
|
console.print(md)
|
||||||
rendered_content = output.getvalue()
|
rendered_content = output.getvalue()
|
||||||
|
print(rendered_content)
|
||||||
# Pipe to pager
|
|
||||||
pager = subprocess.Popen(["less", "-R"], stdin=subprocess.PIPE)
|
|
||||||
pager.communicate(input=rendered_content.encode())
|
|
||||||
|
|
|
@ -179,12 +179,7 @@ class StackBuild(Subcommand):
|
||||||
|
|
||||||
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
|
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
|
||||||
import yaml
|
import yaml
|
||||||
from llama_stack.distribution.distribution import (
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
Api,
|
|
||||||
api_providers,
|
|
||||||
builtin_automatically_routed_apis,
|
|
||||||
)
|
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
|
||||||
from prompt_toolkit import prompt
|
from prompt_toolkit import prompt
|
||||||
from prompt_toolkit.validation import Validator
|
from prompt_toolkit.validation import Validator
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
@ -249,22 +244,12 @@ class StackBuild(Subcommand):
|
||||||
)
|
)
|
||||||
|
|
||||||
cprint(
|
cprint(
|
||||||
f"\n Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.",
|
"\n Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.",
|
||||||
color="green",
|
color="green",
|
||||||
)
|
)
|
||||||
|
|
||||||
providers = dict()
|
providers = dict()
|
||||||
all_providers = api_providers()
|
for api, providers_for_api in get_provider_registry().items():
|
||||||
routing_table_apis = set(
|
|
||||||
x.routing_table_api for x in builtin_automatically_routed_apis()
|
|
||||||
)
|
|
||||||
|
|
||||||
for api in Api:
|
|
||||||
if api in routing_table_apis:
|
|
||||||
continue
|
|
||||||
|
|
||||||
providers_for_api = all_providers[api]
|
|
||||||
|
|
||||||
api_provider = prompt(
|
api_provider = prompt(
|
||||||
"> Enter provider for the {} API: (default=meta-reference): ".format(
|
"> Enter provider for the {} API: (default=meta-reference): ".format(
|
||||||
api.value
|
api.value
|
||||||
|
|
|
@ -34,9 +34,9 @@ class StackListProviders(Subcommand):
|
||||||
|
|
||||||
def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
|
def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
|
||||||
from llama_stack.cli.table import print_table
|
from llama_stack.cli.table import print_table
|
||||||
from llama_stack.distribution.distribution import Api, api_providers
|
from llama_stack.distribution.distribution import Api, get_provider_registry
|
||||||
|
|
||||||
all_providers = api_providers()
|
all_providers = get_provider_registry()
|
||||||
providers_for_api = all_providers[Api(args.api)]
|
providers_for_api = all_providers[Api(args.api)]
|
||||||
|
|
||||||
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
|
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
|
||||||
|
@ -47,11 +47,11 @@ class StackListProviders(Subcommand):
|
||||||
|
|
||||||
rows = []
|
rows = []
|
||||||
for spec in providers_for_api.values():
|
for spec in providers_for_api.values():
|
||||||
if spec.provider_id == "sample":
|
if spec.provider_type == "sample":
|
||||||
continue
|
continue
|
||||||
rows.append(
|
rows.append(
|
||||||
[
|
[
|
||||||
spec.provider_id,
|
spec.provider_type,
|
||||||
",".join(spec.pip_packages),
|
",".join(spec.pip_packages),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -19,6 +19,17 @@ from pathlib import Path
|
||||||
|
|
||||||
from llama_stack.distribution.distribution import api_providers, SERVER_DEPENDENCIES
|
from llama_stack.distribution.distribution import api_providers, SERVER_DEPENDENCIES
|
||||||
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
|
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
|
||||||
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
|
|
||||||
|
|
||||||
|
# These are the dependencies needed by the distribution server.
|
||||||
|
# `llama-stack` is automatically installed by the installation script.
|
||||||
|
SERVER_DEPENDENCIES = [
|
||||||
|
"fastapi",
|
||||||
|
"fire",
|
||||||
|
"httpx",
|
||||||
|
"uvicorn",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class ImageType(Enum):
|
class ImageType(Enum):
|
||||||
|
@ -43,7 +54,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
|
||||||
)
|
)
|
||||||
|
|
||||||
# extend package dependencies based on providers spec
|
# extend package dependencies based on providers spec
|
||||||
all_providers = api_providers()
|
all_providers = get_provider_registry()
|
||||||
for (
|
for (
|
||||||
api_str,
|
api_str,
|
||||||
provider_or_providers,
|
provider_or_providers,
|
||||||
|
|
|
@ -15,8 +15,8 @@ from termcolor import cprint
|
||||||
|
|
||||||
from llama_stack.apis.memory.memory import MemoryBankType
|
from llama_stack.apis.memory.memory import MemoryBankType
|
||||||
from llama_stack.distribution.distribution import (
|
from llama_stack.distribution.distribution import (
|
||||||
api_providers,
|
|
||||||
builtin_automatically_routed_apis,
|
builtin_automatically_routed_apis,
|
||||||
|
get_provider_registry,
|
||||||
stack_apis,
|
stack_apis,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||||
|
@ -62,7 +62,7 @@ def configure_api_providers(
|
||||||
config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
|
config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
|
||||||
|
|
||||||
apis = [v.value for v in stack_apis()]
|
apis = [v.value for v in stack_apis()]
|
||||||
all_providers = api_providers()
|
all_providers = get_provider_registry()
|
||||||
|
|
||||||
# configure simple case for with non-routing providers to api_providers
|
# configure simple case for with non-routing providers to api_providers
|
||||||
for api_str in spec.providers.keys():
|
for api_str in spec.providers.keys():
|
||||||
|
@ -109,7 +109,7 @@ def configure_api_providers(
|
||||||
routing_entries.append(
|
routing_entries.append(
|
||||||
RoutableProviderConfig(
|
RoutableProviderConfig(
|
||||||
routing_key=routing_key,
|
routing_key=routing_key,
|
||||||
provider_id=p,
|
provider_type=p,
|
||||||
config=cfg.dict(),
|
config=cfg.dict(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -120,7 +120,7 @@ def configure_api_providers(
|
||||||
routing_entries.append(
|
routing_entries.append(
|
||||||
RoutableProviderConfig(
|
RoutableProviderConfig(
|
||||||
routing_key=[s.value for s in MetaReferenceShieldType],
|
routing_key=[s.value for s in MetaReferenceShieldType],
|
||||||
provider_id=p,
|
provider_type=p,
|
||||||
config=cfg.dict(),
|
config=cfg.dict(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -133,7 +133,7 @@ def configure_api_providers(
|
||||||
routing_entries.append(
|
routing_entries.append(
|
||||||
RoutableProviderConfig(
|
RoutableProviderConfig(
|
||||||
routing_key=routing_key,
|
routing_key=routing_key,
|
||||||
provider_id=p,
|
provider_type=p,
|
||||||
config=cfg.dict(),
|
config=cfg.dict(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -153,7 +153,7 @@ def configure_api_providers(
|
||||||
routing_entries.append(
|
routing_entries.append(
|
||||||
RoutableProviderConfig(
|
RoutableProviderConfig(
|
||||||
routing_key=routing_key,
|
routing_key=routing_key,
|
||||||
provider_id=p,
|
provider_type=p,
|
||||||
config=cfg.dict(),
|
config=cfg.dict(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -164,7 +164,7 @@ def configure_api_providers(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
config.api_providers[api_str] = GenericProviderConfig(
|
config.api_providers[api_str] = GenericProviderConfig(
|
||||||
provider_id=p,
|
provider_type=p,
|
||||||
config=cfg.dict(),
|
config=cfg.dict(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,53 @@ LLAMA_STACK_BUILD_CONFIG_VERSION = "v1"
|
||||||
LLAMA_STACK_RUN_CONFIG_VERSION = "v1"
|
LLAMA_STACK_RUN_CONFIG_VERSION = "v1"
|
||||||
|
|
||||||
|
|
||||||
|
RoutingKey = Union[str, List[str]]
|
||||||
|
|
||||||
|
|
||||||
|
class GenericProviderConfig(BaseModel):
|
||||||
|
provider_type: str
|
||||||
|
config: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class RoutableProviderConfig(GenericProviderConfig):
|
||||||
|
routing_key: RoutingKey
|
||||||
|
|
||||||
|
|
||||||
|
class PlaceholderProviderConfig(BaseModel):
|
||||||
|
"""Placeholder provider config for API whose provider are defined in routing_table"""
|
||||||
|
|
||||||
|
providers: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
# Example: /inference, /safety
|
||||||
|
class AutoRoutedProviderSpec(ProviderSpec):
|
||||||
|
provider_type: str = "router"
|
||||||
|
config_class: str = ""
|
||||||
|
|
||||||
|
docker_image: Optional[str] = None
|
||||||
|
routing_table_api: Api
|
||||||
|
module: str
|
||||||
|
provider_data_validator: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pip_packages(self) -> List[str]:
|
||||||
|
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
|
||||||
|
|
||||||
|
|
||||||
|
# Example: /models, /shields
|
||||||
|
@json_schema_type
|
||||||
|
class RoutingTableProviderSpec(ProviderSpec):
|
||||||
|
provider_type: str = "routing_table"
|
||||||
|
config_class: str = ""
|
||||||
|
docker_image: Optional[str] = None
|
||||||
|
|
||||||
|
inner_specs: List[ProviderSpec]
|
||||||
|
module: str
|
||||||
|
pip_packages: List[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class DistributionSpec(BaseModel):
|
class DistributionSpec(BaseModel):
|
||||||
description: Optional[str] = Field(
|
description: Optional[str] = Field(
|
||||||
|
@ -71,7 +118,7 @@ Provider configurations for each of the APIs provided by this package.
|
||||||
|
|
||||||
E.g. The following is a ProviderRoutingEntry for models:
|
E.g. The following is a ProviderRoutingEntry for models:
|
||||||
- routing_key: Meta-Llama3.1-8B-Instruct
|
- routing_key: Meta-Llama3.1-8B-Instruct
|
||||||
provider_id: meta-reference
|
provider_type: meta-reference
|
||||||
config:
|
config:
|
||||||
model: Meta-Llama3.1-8B-Instruct
|
model: Meta-Llama3.1-8B-Instruct
|
||||||
quantization: null
|
quantization: null
|
||||||
|
|
|
@ -5,30 +5,11 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.apis.agents import Agents
|
from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec
|
||||||
from llama_stack.apis.inference import Inference
|
|
||||||
from llama_stack.apis.memory import Memory
|
|
||||||
from llama_stack.apis.memory_banks import MemoryBanks
|
|
||||||
from llama_stack.apis.models import Models
|
|
||||||
from llama_stack.apis.safety import Safety
|
|
||||||
from llama_stack.apis.shields import Shields
|
|
||||||
from llama_stack.apis.telemetry import Telemetry
|
|
||||||
|
|
||||||
from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec
|
|
||||||
|
|
||||||
# These are the dependencies needed by the distribution server.
|
|
||||||
# `llama-stack` is automatically installed by the installation script.
|
|
||||||
SERVER_DEPENDENCIES = [
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"httpx",
|
|
||||||
"uvicorn",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def stack_apis() -> List[Api]:
|
def stack_apis() -> List[Api]:
|
||||||
|
@ -57,58 +38,21 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
def providable_apis() -> List[Api]:
|
||||||
apis = {}
|
|
||||||
|
|
||||||
protocols = {
|
|
||||||
Api.inference: Inference,
|
|
||||||
Api.safety: Safety,
|
|
||||||
Api.agents: Agents,
|
|
||||||
Api.memory: Memory,
|
|
||||||
Api.telemetry: Telemetry,
|
|
||||||
Api.models: Models,
|
|
||||||
Api.shields: Shields,
|
|
||||||
Api.memory_banks: MemoryBanks,
|
|
||||||
}
|
|
||||||
|
|
||||||
for api, protocol in protocols.items():
|
|
||||||
endpoints = []
|
|
||||||
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
|
|
||||||
|
|
||||||
for name, method in protocol_methods:
|
|
||||||
if not hasattr(method, "__webmethod__"):
|
|
||||||
continue
|
|
||||||
|
|
||||||
webmethod = method.__webmethod__
|
|
||||||
route = webmethod.route
|
|
||||||
|
|
||||||
if webmethod.method == "GET":
|
|
||||||
method = "get"
|
|
||||||
elif webmethod.method == "DELETE":
|
|
||||||
method = "delete"
|
|
||||||
else:
|
|
||||||
method = "post"
|
|
||||||
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
|
|
||||||
|
|
||||||
apis[api] = endpoints
|
|
||||||
|
|
||||||
return apis
|
|
||||||
|
|
||||||
|
|
||||||
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
|
|
||||||
ret = {}
|
|
||||||
routing_table_apis = set(
|
routing_table_apis = set(
|
||||||
x.routing_table_api for x in builtin_automatically_routed_apis()
|
x.routing_table_api for x in builtin_automatically_routed_apis()
|
||||||
)
|
)
|
||||||
for api in stack_apis():
|
return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
|
||||||
if api in routing_table_apis:
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||||
|
ret = {}
|
||||||
|
for api in providable_apis():
|
||||||
name = api.name.lower()
|
name = api.name.lower()
|
||||||
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
||||||
ret[api] = {
|
ret[api] = {
|
||||||
"remote": remote_provider_spec(api),
|
"remote": remote_provider_spec(api),
|
||||||
**{a.provider_id: a for a in module.available_providers()},
|
**{a.provider_type: a for a in module.available_providers()},
|
||||||
}
|
}
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
54
llama_stack/distribution/inspect.py
Normal file
54
llama_stack/distribution/inspect.py
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Dict, List
|
||||||
|
from llama_stack.apis.inspect import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
|
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||||
|
from llama_stack.providers.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
def is_passthrough(spec: ProviderSpec) -> bool:
|
||||||
|
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
|
||||||
|
|
||||||
|
|
||||||
|
class DistributionInspectImpl(Inspect):
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
|
||||||
|
ret = {}
|
||||||
|
all_providers = get_provider_registry()
|
||||||
|
for api, providers in all_providers.items():
|
||||||
|
ret[api.value] = [
|
||||||
|
ProviderInfo(
|
||||||
|
provider_type=p.provider_type,
|
||||||
|
description="Passthrough" if is_passthrough(p) else "",
|
||||||
|
)
|
||||||
|
for p in providers.values()
|
||||||
|
]
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
async def list_routes(self) -> Dict[str, List[RouteInfo]]:
|
||||||
|
ret = {}
|
||||||
|
all_endpoints = get_all_api_endpoints()
|
||||||
|
|
||||||
|
for api, endpoints in all_endpoints.items():
|
||||||
|
ret[api.value] = [
|
||||||
|
RouteInfo(
|
||||||
|
route=e.route,
|
||||||
|
method=e.method,
|
||||||
|
providers=[],
|
||||||
|
)
|
||||||
|
for e in endpoints
|
||||||
|
]
|
||||||
|
return ret
|
||||||
|
|
||||||
|
async def health(self) -> HealthInfo:
|
||||||
|
return HealthInfo(status="OK")
|
|
@ -18,10 +18,10 @@ class NeedsRequestProviderData:
|
||||||
spec = self.__provider_spec__
|
spec = self.__provider_spec__
|
||||||
assert spec, f"Provider spec not set on {self.__class__}"
|
assert spec, f"Provider spec not set on {self.__class__}"
|
||||||
|
|
||||||
provider_id = spec.provider_id
|
provider_type = spec.provider_type
|
||||||
validator_class = spec.provider_data_validator
|
validator_class = spec.provider_data_validator
|
||||||
if not validator_class:
|
if not validator_class:
|
||||||
raise ValueError(f"Provider {provider_id} does not have a validator")
|
raise ValueError(f"Provider {provider_type} does not have a validator")
|
||||||
|
|
||||||
val = getattr(_THREAD_LOCAL, "provider_data_header_value", None)
|
val = getattr(_THREAD_LOCAL, "provider_data_header_value", None)
|
||||||
if not val:
|
if not val:
|
||||||
|
|
|
@ -3,15 +3,17 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
import importlib
|
||||||
|
|
||||||
from typing import Any, Dict, List, Set
|
from typing import Any, Dict, List, Set
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
from llama_stack.distribution.distribution import (
|
from llama_stack.distribution.distribution import (
|
||||||
api_providers,
|
|
||||||
builtin_automatically_routed_apis,
|
builtin_automatically_routed_apis,
|
||||||
|
get_provider_registry,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_provider
|
from llama_stack.distribution.inspect import DistributionInspectImpl
|
||||||
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||||
|
|
||||||
|
|
||||||
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
|
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
|
||||||
|
@ -20,7 +22,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
||||||
- flatmaps, sorts and resolves the providers in dependency order
|
- flatmaps, sorts and resolves the providers in dependency order
|
||||||
- for each API, produces either a (local, passthrough or router) implementation
|
- for each API, produces either a (local, passthrough or router) implementation
|
||||||
"""
|
"""
|
||||||
all_providers = api_providers()
|
all_providers = get_provider_registry()
|
||||||
specs = {}
|
specs = {}
|
||||||
configs = {}
|
configs = {}
|
||||||
|
|
||||||
|
@ -34,11 +36,11 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
||||||
if isinstance(config, PlaceholderProviderConfig):
|
if isinstance(config, PlaceholderProviderConfig):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if config.provider_id not in providers:
|
if config.provider_type not in providers:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown provider `{config.provider_id}` is not available for API `{api}`"
|
f"Provider `{config.provider_type}` is not available for API `{api}`"
|
||||||
)
|
)
|
||||||
specs[api] = providers[config.provider_id]
|
specs[api] = providers[config.provider_type]
|
||||||
configs[api] = config
|
configs[api] = config
|
||||||
|
|
||||||
apis_to_serve = run_config.apis_to_serve or set(
|
apis_to_serve = run_config.apis_to_serve or set(
|
||||||
|
@ -57,7 +59,6 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
||||||
if info.router_api.value not in apis_to_serve:
|
if info.router_api.value not in apis_to_serve:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print("router_api", info.router_api)
|
|
||||||
if info.router_api.value not in run_config.routing_table:
|
if info.router_api.value not in run_config.routing_table:
|
||||||
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
|
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
|
||||||
|
|
||||||
|
@ -68,12 +69,12 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
||||||
inner_specs = []
|
inner_specs = []
|
||||||
inner_deps = []
|
inner_deps = []
|
||||||
for rt_entry in routing_table:
|
for rt_entry in routing_table:
|
||||||
if rt_entry.provider_id not in providers:
|
if rt_entry.provider_type not in providers:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
|
f"Provider `{rt_entry.provider_type}` is not available for API `{api}`"
|
||||||
)
|
)
|
||||||
inner_specs.append(providers[rt_entry.provider_id])
|
inner_specs.append(providers[rt_entry.provider_type])
|
||||||
inner_deps.extend(providers[rt_entry.provider_id].api_dependencies)
|
inner_deps.extend(providers[rt_entry.provider_type].api_dependencies)
|
||||||
|
|
||||||
specs[source_api] = RoutingTableProviderSpec(
|
specs[source_api] = RoutingTableProviderSpec(
|
||||||
api=source_api,
|
api=source_api,
|
||||||
|
@ -94,7 +95,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
||||||
sorted_specs = topological_sort(specs.values())
|
sorted_specs = topological_sort(specs.values())
|
||||||
print(f"Resolved {len(sorted_specs)} providers in topological order")
|
print(f"Resolved {len(sorted_specs)} providers in topological order")
|
||||||
for spec in sorted_specs:
|
for spec in sorted_specs:
|
||||||
print(f" {spec.api}: {spec.provider_id}")
|
print(f" {spec.api}: {spec.provider_type}")
|
||||||
print("")
|
print("")
|
||||||
impls = {}
|
impls = {}
|
||||||
for spec in sorted_specs:
|
for spec in sorted_specs:
|
||||||
|
@ -104,6 +105,14 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
||||||
|
|
||||||
impls[api] = impl
|
impls[api] = impl
|
||||||
|
|
||||||
|
impls[Api.inspect] = DistributionInspectImpl()
|
||||||
|
specs[Api.inspect] = InlineProviderSpec(
|
||||||
|
api=Api.inspect,
|
||||||
|
provider_type="__distribution_builtin__",
|
||||||
|
config_class="",
|
||||||
|
module="",
|
||||||
|
)
|
||||||
|
|
||||||
return impls, specs
|
return impls, specs
|
||||||
|
|
||||||
|
|
||||||
|
@ -127,3 +136,60 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
||||||
dfs(a, visited, stack)
|
dfs(a, visited, stack)
|
||||||
|
|
||||||
return [by_id[x] for x in stack]
|
return [by_id[x] for x in stack]
|
||||||
|
|
||||||
|
|
||||||
|
# returns a class implementing the protocol corresponding to the Api
|
||||||
|
async def instantiate_provider(
|
||||||
|
provider_spec: ProviderSpec,
|
||||||
|
deps: Dict[str, Any],
|
||||||
|
provider_config: Union[GenericProviderConfig, RoutingTable],
|
||||||
|
):
|
||||||
|
module = importlib.import_module(provider_spec.module)
|
||||||
|
|
||||||
|
args = []
|
||||||
|
if isinstance(provider_spec, RemoteProviderSpec):
|
||||||
|
if provider_spec.adapter:
|
||||||
|
method = "get_adapter_impl"
|
||||||
|
else:
|
||||||
|
method = "get_client_impl"
|
||||||
|
|
||||||
|
assert isinstance(provider_config, GenericProviderConfig)
|
||||||
|
config_type = instantiate_class_type(provider_spec.config_class)
|
||||||
|
config = config_type(**provider_config.config)
|
||||||
|
args = [config, deps]
|
||||||
|
elif isinstance(provider_spec, AutoRoutedProviderSpec):
|
||||||
|
method = "get_auto_router_impl"
|
||||||
|
|
||||||
|
config = None
|
||||||
|
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
|
||||||
|
elif isinstance(provider_spec, RoutingTableProviderSpec):
|
||||||
|
method = "get_routing_table_impl"
|
||||||
|
|
||||||
|
assert isinstance(provider_config, List)
|
||||||
|
routing_table = provider_config
|
||||||
|
|
||||||
|
inner_specs = {x.provider_type: x for x in provider_spec.inner_specs}
|
||||||
|
inner_impls = []
|
||||||
|
for routing_entry in routing_table:
|
||||||
|
impl = await instantiate_provider(
|
||||||
|
inner_specs[routing_entry.provider_type],
|
||||||
|
deps,
|
||||||
|
routing_entry,
|
||||||
|
)
|
||||||
|
inner_impls.append((routing_entry.routing_key, impl))
|
||||||
|
|
||||||
|
config = None
|
||||||
|
args = [provider_spec.api, inner_impls, routing_table, deps]
|
||||||
|
else:
|
||||||
|
method = "get_provider_impl"
|
||||||
|
|
||||||
|
assert isinstance(provider_config, GenericProviderConfig)
|
||||||
|
config_type = instantiate_class_type(provider_spec.config_class)
|
||||||
|
config = config_type(**provider_config.config)
|
||||||
|
args = [config, deps]
|
||||||
|
|
||||||
|
fn = getattr(module, method)
|
||||||
|
impl = await fn(*args)
|
||||||
|
impl.__provider_spec__ = provider_spec
|
||||||
|
impl.__provider_config__ = config
|
||||||
|
return impl
|
||||||
|
|
|
@ -94,12 +94,21 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
async def list_shields(self) -> List[ShieldSpec]:
|
async def list_shields(self) -> List[ShieldSpec]:
|
||||||
specs = []
|
specs = []
|
||||||
for entry in self.routing_table_config:
|
for entry in self.routing_table_config:
|
||||||
specs.append(
|
if isinstance(entry.routing_key, list):
|
||||||
ShieldSpec(
|
for k in entry.routing_key:
|
||||||
shield_type=entry.routing_key,
|
specs.append(
|
||||||
provider_config=entry,
|
ShieldSpec(
|
||||||
|
shield_type=k,
|
||||||
|
provider_config=entry,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
specs.append(
|
||||||
|
ShieldSpec(
|
||||||
|
shield_type=entry.routing_key,
|
||||||
|
provider_config=entry,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
return specs
|
return specs
|
||||||
|
|
||||||
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
|
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
|
||||||
|
|
67
llama_stack/distribution/server/endpoints.py
Normal file
67
llama_stack/distribution/server/endpoints.py
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from llama_stack.apis.agents import Agents
|
||||||
|
from llama_stack.apis.inference import Inference
|
||||||
|
from llama_stack.apis.inspect import Inspect
|
||||||
|
from llama_stack.apis.memory import Memory
|
||||||
|
from llama_stack.apis.memory_banks import MemoryBanks
|
||||||
|
from llama_stack.apis.models import Models
|
||||||
|
from llama_stack.apis.safety import Safety
|
||||||
|
from llama_stack.apis.shields import Shields
|
||||||
|
from llama_stack.apis.telemetry import Telemetry
|
||||||
|
|
||||||
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
|
|
||||||
|
class ApiEndpoint(BaseModel):
|
||||||
|
route: str
|
||||||
|
method: str
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
||||||
|
apis = {}
|
||||||
|
|
||||||
|
protocols = {
|
||||||
|
Api.inference: Inference,
|
||||||
|
Api.safety: Safety,
|
||||||
|
Api.agents: Agents,
|
||||||
|
Api.memory: Memory,
|
||||||
|
Api.telemetry: Telemetry,
|
||||||
|
Api.models: Models,
|
||||||
|
Api.shields: Shields,
|
||||||
|
Api.memory_banks: MemoryBanks,
|
||||||
|
Api.inspect: Inspect,
|
||||||
|
}
|
||||||
|
|
||||||
|
for api, protocol in protocols.items():
|
||||||
|
endpoints = []
|
||||||
|
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
|
||||||
|
|
||||||
|
for name, method in protocol_methods:
|
||||||
|
if not hasattr(method, "__webmethod__"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
webmethod = method.__webmethod__
|
||||||
|
route = webmethod.route
|
||||||
|
|
||||||
|
if webmethod.method == "GET":
|
||||||
|
method = "get"
|
||||||
|
elif webmethod.method == "DELETE":
|
||||||
|
method = "delete"
|
||||||
|
else:
|
||||||
|
method = "post"
|
||||||
|
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
|
||||||
|
|
||||||
|
apis[api] = endpoints
|
||||||
|
|
||||||
|
return apis
|
|
@ -15,7 +15,6 @@ from collections.abc import (
|
||||||
AsyncIterator as AsyncIteratorABC,
|
AsyncIterator as AsyncIteratorABC,
|
||||||
)
|
)
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from http import HTTPStatus
|
|
||||||
from ssl import SSLError
|
from ssl import SSLError
|
||||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
|
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
|
||||||
|
|
||||||
|
@ -26,7 +25,6 @@ import yaml
|
||||||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from fastapi.routing import APIRoute
|
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
@ -39,10 +37,11 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.distribution.distribution import api_endpoints
|
|
||||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||||
from llama_stack.distribution.resolver import resolve_impls_with_routing
|
from llama_stack.distribution.resolver import resolve_impls_with_routing
|
||||||
|
|
||||||
|
from .endpoints import get_all_api_endpoints
|
||||||
|
|
||||||
|
|
||||||
def is_async_iterator_type(typ):
|
def is_async_iterator_type(typ):
|
||||||
if hasattr(typ, "__origin__"):
|
if hasattr(typ, "__origin__"):
|
||||||
|
@ -286,26 +285,18 @@ def main(
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
# Health check is added to enable deploying the docker container image on Kubernetes which require
|
|
||||||
# a health check that can return 200 for readiness and liveness check
|
|
||||||
class HealthCheck(BaseModel):
|
|
||||||
status: str = "OK"
|
|
||||||
|
|
||||||
@app.get("/healthcheck", status_code=HTTPStatus.OK, response_model=HealthCheck)
|
|
||||||
async def healthcheck():
|
|
||||||
return HealthCheck(status="OK")
|
|
||||||
|
|
||||||
impls, specs = asyncio.run(resolve_impls_with_routing(config))
|
impls, specs = asyncio.run(resolve_impls_with_routing(config))
|
||||||
if Api.telemetry in impls:
|
if Api.telemetry in impls:
|
||||||
setup_logger(impls[Api.telemetry])
|
setup_logger(impls[Api.telemetry])
|
||||||
|
|
||||||
all_endpoints = api_endpoints()
|
all_endpoints = get_all_api_endpoints()
|
||||||
|
|
||||||
if config.apis_to_serve:
|
if config.apis_to_serve:
|
||||||
apis_to_serve = set(config.apis_to_serve)
|
apis_to_serve = set(config.apis_to_serve)
|
||||||
else:
|
else:
|
||||||
apis_to_serve = set(impls.keys())
|
apis_to_serve = set(impls.keys())
|
||||||
|
|
||||||
|
apis_to_serve.add(Api.inspect)
|
||||||
for api_str in apis_to_serve:
|
for api_str in apis_to_serve:
|
||||||
api = Api(api_str)
|
api = Api(api_str)
|
||||||
|
|
||||||
|
@ -339,14 +330,11 @@ def main(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
for route in app.routes:
|
cprint(f"Serving API {api_str}", "white", attrs=["bold"])
|
||||||
if isinstance(route, APIRoute):
|
for endpoint in endpoints:
|
||||||
cprint(
|
cprint(f" {endpoint.method.upper()} {endpoint.route}", "white")
|
||||||
f"Serving {next(iter(route.methods))} {route.path}",
|
|
||||||
"white",
|
|
||||||
attrs=["bold"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
print("")
|
||||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||||
app.exception_handler(Exception)(global_exception_handler)
|
app.exception_handler(Exception)(global_exception_handler)
|
||||||
signal.signal(signal.SIGINT, handle_sigint)
|
signal.signal(signal.SIGINT, handle_sigint)
|
||||||
|
|
|
@ -18,7 +18,7 @@ api_providers:
|
||||||
providers:
|
providers:
|
||||||
- meta-reference
|
- meta-reference
|
||||||
agents:
|
agents:
|
||||||
provider_id: meta-reference
|
provider_type: meta-reference
|
||||||
config:
|
config:
|
||||||
persistence_store:
|
persistence_store:
|
||||||
namespace: null
|
namespace: null
|
||||||
|
@ -28,22 +28,22 @@ api_providers:
|
||||||
providers:
|
providers:
|
||||||
- meta-reference
|
- meta-reference
|
||||||
telemetry:
|
telemetry:
|
||||||
provider_id: meta-reference
|
provider_type: meta-reference
|
||||||
config: {}
|
config: {}
|
||||||
routing_table:
|
routing_table:
|
||||||
inference:
|
inference:
|
||||||
- provider_id: remote::ollama
|
- provider_type: remote::ollama
|
||||||
config:
|
config:
|
||||||
host: localhost
|
host: localhost
|
||||||
port: 6000
|
port: 6000
|
||||||
routing_key: Meta-Llama3.1-8B-Instruct
|
routing_key: Meta-Llama3.1-8B-Instruct
|
||||||
safety:
|
safety:
|
||||||
- provider_id: meta-reference
|
- provider_type: meta-reference
|
||||||
config:
|
config:
|
||||||
llama_guard_shield: null
|
llama_guard_shield: null
|
||||||
prompt_guard_shield: null
|
prompt_guard_shield: null
|
||||||
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
|
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
|
||||||
memory:
|
memory:
|
||||||
- provider_id: meta-reference
|
- provider_type: meta-reference
|
||||||
config: {}
|
config: {}
|
||||||
routing_key: vector
|
routing_key: vector
|
||||||
|
|
|
@ -18,7 +18,7 @@ api_providers:
|
||||||
providers:
|
providers:
|
||||||
- meta-reference
|
- meta-reference
|
||||||
agents:
|
agents:
|
||||||
provider_id: meta-reference
|
provider_type: meta-reference
|
||||||
config:
|
config:
|
||||||
persistence_store:
|
persistence_store:
|
||||||
namespace: null
|
namespace: null
|
||||||
|
@ -28,11 +28,11 @@ api_providers:
|
||||||
providers:
|
providers:
|
||||||
- meta-reference
|
- meta-reference
|
||||||
telemetry:
|
telemetry:
|
||||||
provider_id: meta-reference
|
provider_type: meta-reference
|
||||||
config: {}
|
config: {}
|
||||||
routing_table:
|
routing_table:
|
||||||
inference:
|
inference:
|
||||||
- provider_id: meta-reference
|
- provider_type: meta-reference
|
||||||
config:
|
config:
|
||||||
model: Llama3.1-8B-Instruct
|
model: Llama3.1-8B-Instruct
|
||||||
quantization: null
|
quantization: null
|
||||||
|
@ -41,12 +41,12 @@ routing_table:
|
||||||
max_batch_size: 1
|
max_batch_size: 1
|
||||||
routing_key: Llama3.1-8B-Instruct
|
routing_key: Llama3.1-8B-Instruct
|
||||||
safety:
|
safety:
|
||||||
- provider_id: meta-reference
|
- provider_type: meta-reference
|
||||||
config:
|
config:
|
||||||
llama_guard_shield: null
|
llama_guard_shield: null
|
||||||
prompt_guard_shield: null
|
prompt_guard_shield: null
|
||||||
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
|
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
|
||||||
memory:
|
memory:
|
||||||
- provider_id: meta-reference
|
- provider_type: meta-reference
|
||||||
config: {}
|
config: {}
|
||||||
routing_key: vector
|
routing_key: vector
|
||||||
|
|
|
@ -5,69 +5,9 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
|
||||||
|
|
||||||
|
|
||||||
def instantiate_class_type(fully_qualified_name):
|
def instantiate_class_type(fully_qualified_name):
|
||||||
module_name, class_name = fully_qualified_name.rsplit(".", 1)
|
module_name, class_name = fully_qualified_name.rsplit(".", 1)
|
||||||
module = importlib.import_module(module_name)
|
module = importlib.import_module(module_name)
|
||||||
return getattr(module, class_name)
|
return getattr(module, class_name)
|
||||||
|
|
||||||
|
|
||||||
# returns a class implementing the protocol corresponding to the Api
|
|
||||||
async def instantiate_provider(
|
|
||||||
provider_spec: ProviderSpec,
|
|
||||||
deps: Dict[str, Any],
|
|
||||||
provider_config: Union[GenericProviderConfig, RoutingTable],
|
|
||||||
):
|
|
||||||
module = importlib.import_module(provider_spec.module)
|
|
||||||
|
|
||||||
args = []
|
|
||||||
if isinstance(provider_spec, RemoteProviderSpec):
|
|
||||||
if provider_spec.adapter:
|
|
||||||
method = "get_adapter_impl"
|
|
||||||
else:
|
|
||||||
method = "get_client_impl"
|
|
||||||
|
|
||||||
assert isinstance(provider_config, GenericProviderConfig)
|
|
||||||
config_type = instantiate_class_type(provider_spec.config_class)
|
|
||||||
config = config_type(**provider_config.config)
|
|
||||||
args = [config, deps]
|
|
||||||
elif isinstance(provider_spec, AutoRoutedProviderSpec):
|
|
||||||
method = "get_auto_router_impl"
|
|
||||||
|
|
||||||
config = None
|
|
||||||
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
|
|
||||||
elif isinstance(provider_spec, RoutingTableProviderSpec):
|
|
||||||
method = "get_routing_table_impl"
|
|
||||||
|
|
||||||
assert isinstance(provider_config, List)
|
|
||||||
routing_table = provider_config
|
|
||||||
|
|
||||||
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
|
|
||||||
inner_impls = []
|
|
||||||
for routing_entry in routing_table:
|
|
||||||
impl = await instantiate_provider(
|
|
||||||
inner_specs[routing_entry.provider_id],
|
|
||||||
deps,
|
|
||||||
routing_entry,
|
|
||||||
)
|
|
||||||
inner_impls.append((routing_entry.routing_key, impl))
|
|
||||||
|
|
||||||
config = None
|
|
||||||
args = [provider_spec.api, inner_impls, routing_table, deps]
|
|
||||||
else:
|
|
||||||
method = "get_provider_impl"
|
|
||||||
|
|
||||||
assert isinstance(provider_config, GenericProviderConfig)
|
|
||||||
config_type = instantiate_class_type(provider_spec.config_class)
|
|
||||||
config = config_type(**provider_config.config)
|
|
||||||
args = [config, deps]
|
|
||||||
|
|
||||||
fn = getattr(module, method)
|
|
||||||
impl = await fn(*args)
|
|
||||||
impl.__provider_spec__ = provider_spec
|
|
||||||
impl.__provider_config__ = config
|
|
||||||
return impl
|
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional, Protocol, Union
|
from typing import Any, List, Optional, Protocol
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
from llama_models.schema_utils import json_schema_type
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
@ -24,18 +24,14 @@ class Api(Enum):
|
||||||
shields = "shields"
|
shields = "shields"
|
||||||
memory_banks = "memory_banks"
|
memory_banks = "memory_banks"
|
||||||
|
|
||||||
|
# built-in API
|
||||||
@json_schema_type
|
inspect = "inspect"
|
||||||
class ApiEndpoint(BaseModel):
|
|
||||||
route: str
|
|
||||||
method: str
|
|
||||||
name: str
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ProviderSpec(BaseModel):
|
class ProviderSpec(BaseModel):
|
||||||
api: Api
|
api: Api
|
||||||
provider_id: str
|
provider_type: str
|
||||||
config_class: str = Field(
|
config_class: str = Field(
|
||||||
...,
|
...,
|
||||||
description="Fully-qualified classname of the config for this provider",
|
description="Fully-qualified classname of the config for this provider",
|
||||||
|
@ -62,71 +58,9 @@ class RoutableProvider(Protocol):
|
||||||
async def validate_routing_keys(self, keys: List[str]) -> None: ...
|
async def validate_routing_keys(self, keys: List[str]) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class GenericProviderConfig(BaseModel):
|
|
||||||
provider_id: str
|
|
||||||
config: Dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
class PlaceholderProviderConfig(BaseModel):
|
|
||||||
"""Placeholder provider config for API whose provider are defined in routing_table"""
|
|
||||||
|
|
||||||
providers: List[str]
|
|
||||||
|
|
||||||
|
|
||||||
RoutingKey = Union[str, List[str]]
|
|
||||||
|
|
||||||
|
|
||||||
class RoutableProviderConfig(GenericProviderConfig):
|
|
||||||
routing_key: RoutingKey
|
|
||||||
|
|
||||||
|
|
||||||
# Example: /inference, /safety
|
|
||||||
@json_schema_type
|
|
||||||
class AutoRoutedProviderSpec(ProviderSpec):
|
|
||||||
provider_id: str = "router"
|
|
||||||
config_class: str = ""
|
|
||||||
|
|
||||||
docker_image: Optional[str] = None
|
|
||||||
routing_table_api: Api
|
|
||||||
module: str = Field(
|
|
||||||
...,
|
|
||||||
description="""
|
|
||||||
Fully-qualified name of the module to import. The module is expected to have:
|
|
||||||
|
|
||||||
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
provider_data_validator: Optional[str] = Field(
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def pip_packages(self) -> List[str]:
|
|
||||||
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
|
|
||||||
|
|
||||||
|
|
||||||
# Example: /models, /shields
|
|
||||||
@json_schema_type
|
|
||||||
class RoutingTableProviderSpec(ProviderSpec):
|
|
||||||
provider_id: str = "routing_table"
|
|
||||||
config_class: str = ""
|
|
||||||
docker_image: Optional[str] = None
|
|
||||||
|
|
||||||
inner_specs: List[ProviderSpec]
|
|
||||||
module: str = Field(
|
|
||||||
...,
|
|
||||||
description="""
|
|
||||||
Fully-qualified name of the module to import. The module is expected to have:
|
|
||||||
|
|
||||||
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
pip_packages: List[str] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AdapterSpec(BaseModel):
|
class AdapterSpec(BaseModel):
|
||||||
adapter_id: str = Field(
|
adapter_type: str = Field(
|
||||||
...,
|
...,
|
||||||
description="Unique identifier for this adapter",
|
description="Unique identifier for this adapter",
|
||||||
)
|
)
|
||||||
|
@ -186,10 +120,6 @@ class RemoteProviderConfig(BaseModel):
|
||||||
return f"http://{self.host}:{self.port}"
|
return f"http://{self.host}:{self.port}"
|
||||||
|
|
||||||
|
|
||||||
def remote_provider_id(adapter_id: str) -> str:
|
|
||||||
return f"remote::{adapter_id}"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class RemoteProviderSpec(ProviderSpec):
|
class RemoteProviderSpec(ProviderSpec):
|
||||||
adapter: Optional[AdapterSpec] = Field(
|
adapter: Optional[AdapterSpec] = Field(
|
||||||
|
@ -233,8 +163,8 @@ def remote_provider_spec(
|
||||||
if adapter and adapter.config_class
|
if adapter and adapter.config_class
|
||||||
else "llama_stack.distribution.datatypes.RemoteProviderConfig"
|
else "llama_stack.distribution.datatypes.RemoteProviderConfig"
|
||||||
)
|
)
|
||||||
provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"
|
provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
|
||||||
|
|
||||||
return RemoteProviderSpec(
|
return RemoteProviderSpec(
|
||||||
api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
|
api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
|
||||||
)
|
)
|
||||||
|
|
|
@ -50,20 +50,6 @@ class LlamaGuardShieldConfig(BaseModel):
|
||||||
class PromptGuardShieldConfig(BaseModel):
|
class PromptGuardShieldConfig(BaseModel):
|
||||||
model: str = "Prompt-Guard-86M"
|
model: str = "Prompt-Guard-86M"
|
||||||
|
|
||||||
@validator("model")
|
|
||||||
@classmethod
|
|
||||||
def validate_model(cls, model: str) -> str:
|
|
||||||
permitted_models = [
|
|
||||||
m.descriptor()
|
|
||||||
for m in safety_models()
|
|
||||||
if m.core_model_id == CoreModelId.prompt_guard_86m
|
|
||||||
]
|
|
||||||
if model not in permitted_models:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid model: {model}. Must be one of {permitted_models}"
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
class SafetyConfig(BaseModel):
|
class SafetyConfig(BaseModel):
|
||||||
llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
|
llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
|
||||||
|
|
|
@ -14,7 +14,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
return [
|
return [
|
||||||
InlineProviderSpec(
|
InlineProviderSpec(
|
||||||
api=Api.agents,
|
api=Api.agents,
|
||||||
provider_id="meta-reference",
|
provider_type="meta-reference",
|
||||||
pip_packages=[
|
pip_packages=[
|
||||||
"matplotlib",
|
"matplotlib",
|
||||||
"pillow",
|
"pillow",
|
||||||
|
@ -33,7 +33,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.agents,
|
api=Api.agents,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_id="sample",
|
adapter_type="sample",
|
||||||
pip_packages=[],
|
pip_packages=[],
|
||||||
module="llama_stack.providers.adapters.agents.sample",
|
module="llama_stack.providers.adapters.agents.sample",
|
||||||
config_class="llama_stack.providers.adapters.agents.sample.SampleConfig",
|
config_class="llama_stack.providers.adapters.agents.sample.SampleConfig",
|
||||||
|
|
|
@ -13,7 +13,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
return [
|
return [
|
||||||
InlineProviderSpec(
|
InlineProviderSpec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
provider_id="meta-reference",
|
provider_type="meta-reference",
|
||||||
pip_packages=[
|
pip_packages=[
|
||||||
"accelerate",
|
"accelerate",
|
||||||
"blobfile",
|
"blobfile",
|
||||||
|
@ -30,7 +30,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_id="sample",
|
adapter_type="sample",
|
||||||
pip_packages=[],
|
pip_packages=[],
|
||||||
module="llama_stack.providers.adapters.inference.sample",
|
module="llama_stack.providers.adapters.inference.sample",
|
||||||
config_class="llama_stack.providers.adapters.inference.sample.SampleConfig",
|
config_class="llama_stack.providers.adapters.inference.sample.SampleConfig",
|
||||||
|
@ -39,7 +39,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_id="ollama",
|
adapter_type="ollama",
|
||||||
pip_packages=["ollama"],
|
pip_packages=["ollama"],
|
||||||
module="llama_stack.providers.adapters.inference.ollama",
|
module="llama_stack.providers.adapters.inference.ollama",
|
||||||
),
|
),
|
||||||
|
@ -47,7 +47,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_id="tgi",
|
adapter_type="tgi",
|
||||||
pip_packages=["huggingface_hub", "aiohttp"],
|
pip_packages=["huggingface_hub", "aiohttp"],
|
||||||
module="llama_stack.providers.adapters.inference.tgi",
|
module="llama_stack.providers.adapters.inference.tgi",
|
||||||
config_class="llama_stack.providers.adapters.inference.tgi.TGIImplConfig",
|
config_class="llama_stack.providers.adapters.inference.tgi.TGIImplConfig",
|
||||||
|
@ -56,7 +56,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_id="hf::serverless",
|
adapter_type="hf::serverless",
|
||||||
pip_packages=["huggingface_hub", "aiohttp"],
|
pip_packages=["huggingface_hub", "aiohttp"],
|
||||||
module="llama_stack.providers.adapters.inference.tgi",
|
module="llama_stack.providers.adapters.inference.tgi",
|
||||||
config_class="llama_stack.providers.adapters.inference.tgi.InferenceAPIImplConfig",
|
config_class="llama_stack.providers.adapters.inference.tgi.InferenceAPIImplConfig",
|
||||||
|
@ -65,7 +65,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_id="hf::endpoint",
|
adapter_type="hf::endpoint",
|
||||||
pip_packages=["huggingface_hub", "aiohttp"],
|
pip_packages=["huggingface_hub", "aiohttp"],
|
||||||
module="llama_stack.providers.adapters.inference.tgi",
|
module="llama_stack.providers.adapters.inference.tgi",
|
||||||
config_class="llama_stack.providers.adapters.inference.tgi.InferenceEndpointImplConfig",
|
config_class="llama_stack.providers.adapters.inference.tgi.InferenceEndpointImplConfig",
|
||||||
|
@ -74,7 +74,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_id="fireworks",
|
adapter_type="fireworks",
|
||||||
pip_packages=[
|
pip_packages=[
|
||||||
"fireworks-ai",
|
"fireworks-ai",
|
||||||
],
|
],
|
||||||
|
@ -85,7 +85,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_id="together",
|
adapter_type="together",
|
||||||
pip_packages=[
|
pip_packages=[
|
||||||
"together",
|
"together",
|
||||||
],
|
],
|
||||||
|
@ -97,10 +97,8 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_id="bedrock",
|
adapter_type="bedrock",
|
||||||
pip_packages=[
|
pip_packages=["boto3"],
|
||||||
"boto3"
|
|
||||||
],
|
|
||||||
module="llama_stack.providers.adapters.inference.bedrock",
|
module="llama_stack.providers.adapters.inference.bedrock",
|
||||||
config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig",
|
config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig",
|
||||||
),
|
),
|
||||||
|
|
|
@ -34,7 +34,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
return [
|
return [
|
||||||
InlineProviderSpec(
|
InlineProviderSpec(
|
||||||
api=Api.memory,
|
api=Api.memory,
|
||||||
provider_id="meta-reference",
|
provider_type="meta-reference",
|
||||||
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
|
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
|
||||||
module="llama_stack.providers.impls.meta_reference.memory",
|
module="llama_stack.providers.impls.meta_reference.memory",
|
||||||
config_class="llama_stack.providers.impls.meta_reference.memory.FaissImplConfig",
|
config_class="llama_stack.providers.impls.meta_reference.memory.FaissImplConfig",
|
||||||
|
@ -42,7 +42,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
Api.memory,
|
Api.memory,
|
||||||
AdapterSpec(
|
AdapterSpec(
|
||||||
adapter_id="chromadb",
|
adapter_type="chromadb",
|
||||||
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
|
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
|
||||||
module="llama_stack.providers.adapters.memory.chroma",
|
module="llama_stack.providers.adapters.memory.chroma",
|
||||||
),
|
),
|
||||||
|
@ -50,7 +50,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
Api.memory,
|
Api.memory,
|
||||||
AdapterSpec(
|
AdapterSpec(
|
||||||
adapter_id="pgvector",
|
adapter_type="pgvector",
|
||||||
pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
|
pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
|
||||||
module="llama_stack.providers.adapters.memory.pgvector",
|
module="llama_stack.providers.adapters.memory.pgvector",
|
||||||
config_class="llama_stack.providers.adapters.memory.pgvector.PGVectorConfig",
|
config_class="llama_stack.providers.adapters.memory.pgvector.PGVectorConfig",
|
||||||
|
@ -59,7 +59,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.memory,
|
api=Api.memory,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_id="sample",
|
adapter_type="sample",
|
||||||
pip_packages=[],
|
pip_packages=[],
|
||||||
module="llama_stack.providers.adapters.memory.sample",
|
module="llama_stack.providers.adapters.memory.sample",
|
||||||
config_class="llama_stack.providers.adapters.memory.sample.SampleConfig",
|
config_class="llama_stack.providers.adapters.memory.sample.SampleConfig",
|
||||||
|
|
|
@ -19,7 +19,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
return [
|
return [
|
||||||
InlineProviderSpec(
|
InlineProviderSpec(
|
||||||
api=Api.safety,
|
api=Api.safety,
|
||||||
provider_id="meta-reference",
|
provider_type="meta-reference",
|
||||||
pip_packages=[
|
pip_packages=[
|
||||||
"codeshield",
|
"codeshield",
|
||||||
"transformers",
|
"transformers",
|
||||||
|
@ -34,7 +34,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.safety,
|
api=Api.safety,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_id="sample",
|
adapter_type="sample",
|
||||||
pip_packages=[],
|
pip_packages=[],
|
||||||
module="llama_stack.providers.adapters.safety.sample",
|
module="llama_stack.providers.adapters.safety.sample",
|
||||||
config_class="llama_stack.providers.adapters.safety.sample.SampleConfig",
|
config_class="llama_stack.providers.adapters.safety.sample.SampleConfig",
|
||||||
|
@ -43,7 +43,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.safety,
|
api=Api.safety,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_id="bedrock",
|
adapter_type="bedrock",
|
||||||
pip_packages=["boto3"],
|
pip_packages=["boto3"],
|
||||||
module="llama_stack.providers.adapters.safety.bedrock",
|
module="llama_stack.providers.adapters.safety.bedrock",
|
||||||
config_class="llama_stack.providers.adapters.safety.bedrock.BedrockSafetyConfig",
|
config_class="llama_stack.providers.adapters.safety.bedrock.BedrockSafetyConfig",
|
||||||
|
@ -52,7 +52,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.safety,
|
api=Api.safety,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_id="together",
|
adapter_type="together",
|
||||||
pip_packages=[
|
pip_packages=[
|
||||||
"together",
|
"together",
|
||||||
],
|
],
|
||||||
|
|
|
@ -13,7 +13,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
return [
|
return [
|
||||||
InlineProviderSpec(
|
InlineProviderSpec(
|
||||||
api=Api.telemetry,
|
api=Api.telemetry,
|
||||||
provider_id="meta-reference",
|
provider_type="meta-reference",
|
||||||
pip_packages=[],
|
pip_packages=[],
|
||||||
module="llama_stack.providers.impls.meta_reference.telemetry",
|
module="llama_stack.providers.impls.meta_reference.telemetry",
|
||||||
config_class="llama_stack.providers.impls.meta_reference.telemetry.ConsoleConfig",
|
config_class="llama_stack.providers.impls.meta_reference.telemetry.ConsoleConfig",
|
||||||
|
@ -21,7 +21,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.telemetry,
|
api=Api.telemetry,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_id="sample",
|
adapter_type="sample",
|
||||||
pip_packages=[],
|
pip_packages=[],
|
||||||
module="llama_stack.providers.adapters.telemetry.sample",
|
module="llama_stack.providers.adapters.telemetry.sample",
|
||||||
config_class="llama_stack.providers.adapters.telemetry.sample.SampleConfig",
|
config_class="llama_stack.providers.adapters.telemetry.sample.SampleConfig",
|
||||||
|
@ -30,7 +30,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.telemetry,
|
api=Api.telemetry,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_id="opentelemetry-jaeger",
|
adapter_type="opentelemetry-jaeger",
|
||||||
pip_packages=[
|
pip_packages=[
|
||||||
"opentelemetry-api",
|
"opentelemetry-api",
|
||||||
"opentelemetry-sdk",
|
"opentelemetry-sdk",
|
||||||
|
|
|
@ -34,7 +34,8 @@ def augment_messages_for_tools(request: ChatCompletionRequest) -> List[Message]:
|
||||||
return request.messages
|
return request.messages
|
||||||
|
|
||||||
if model.model_family == ModelFamily.llama3_1 or (
|
if model.model_family == ModelFamily.llama3_1 or (
|
||||||
model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
|
model.model_family == ModelFamily.llama3_2
|
||||||
|
and is_multimodal(model.core_model_id)
|
||||||
):
|
):
|
||||||
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
|
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
|
||||||
return augment_messages_for_tools_llama_3_1(request)
|
return augment_messages_for_tools_llama_3_1(request)
|
||||||
|
|
|
@ -2,7 +2,7 @@ blobfile
|
||||||
fire
|
fire
|
||||||
httpx
|
httpx
|
||||||
huggingface-hub
|
huggingface-hub
|
||||||
llama-models>=0.0.37
|
llama-models>=0.0.38
|
||||||
prompt-toolkit
|
prompt-toolkit
|
||||||
python-dotenv
|
python-dotenv
|
||||||
pydantic>=2
|
pydantic>=2
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -16,7 +16,7 @@ def read_requirements():
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="llama_stack",
|
name="llama_stack",
|
||||||
version="0.0.37",
|
version="0.0.38",
|
||||||
author="Meta Llama",
|
author="Meta Llama",
|
||||||
author_email="llama-oss@meta.com",
|
author_email="llama-oss@meta.com",
|
||||||
description="Llama Stack",
|
description="Llama Stack",
|
||||||
|
|
|
@ -18,7 +18,7 @@ api_providers:
|
||||||
providers:
|
providers:
|
||||||
- meta-reference
|
- meta-reference
|
||||||
agents:
|
agents:
|
||||||
provider_id: meta-reference
|
provider_type: meta-reference
|
||||||
config:
|
config:
|
||||||
persistence_store:
|
persistence_store:
|
||||||
namespace: null
|
namespace: null
|
||||||
|
@ -28,11 +28,11 @@ api_providers:
|
||||||
providers:
|
providers:
|
||||||
- meta-reference
|
- meta-reference
|
||||||
telemetry:
|
telemetry:
|
||||||
provider_id: meta-reference
|
provider_type: meta-reference
|
||||||
config: {}
|
config: {}
|
||||||
routing_table:
|
routing_table:
|
||||||
inference:
|
inference:
|
||||||
- provider_id: meta-reference
|
- provider_type: meta-reference
|
||||||
config:
|
config:
|
||||||
model: Meta-Llama3.1-8B-Instruct
|
model: Meta-Llama3.1-8B-Instruct
|
||||||
quantization: null
|
quantization: null
|
||||||
|
@ -41,7 +41,7 @@ routing_table:
|
||||||
max_batch_size: 1
|
max_batch_size: 1
|
||||||
routing_key: Meta-Llama3.1-8B-Instruct
|
routing_key: Meta-Llama3.1-8B-Instruct
|
||||||
safety:
|
safety:
|
||||||
- provider_id: meta-reference
|
- provider_type: meta-reference
|
||||||
config:
|
config:
|
||||||
llama_guard_shield:
|
llama_guard_shield:
|
||||||
model: Llama-Guard-3-1B
|
model: Llama-Guard-3-1B
|
||||||
|
@ -52,6 +52,6 @@ routing_table:
|
||||||
model: Prompt-Guard-86M
|
model: Prompt-Guard-86M
|
||||||
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
|
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
|
||||||
memory:
|
memory:
|
||||||
- provider_id: meta-reference
|
- provider_type: meta-reference
|
||||||
config: {}
|
config: {}
|
||||||
routing_key: vector
|
routing_key: vector
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue