Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-30 07:39:38 +00:00)

Merge branch 'main' into docs_improvement

Commit 1422d631a8: 15 changed files with 554 additions and 340 deletions
OpenAPI generator (Python):

```diff
@@ -315,7 +315,20 @@ def get_endpoint_operations(
             )
         else:
             event_type = None
             response_type = return_type

+        def process_type(t):
+            if typing.get_origin(t) is collections.abc.AsyncIterator:
+                # NOTE(ashwin): this is SSE and there is no way to represent it. either we make it a List
+                # or the item type. I am choosing it to be the latter
+                args = typing.get_args(t)
+                return args[0]
+            elif typing.get_origin(t) is typing.Union:
+                types = [process_type(a) for a in typing.get_args(t)]
+                return typing._UnionGenericAlias(typing.Union, tuple(types))
+            else:
+                return t
+
+        response_type = process_type(return_type)

         # set HTTP request method based on type of request and presence of payload
         if not request_params:
```
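For context, the `process_type` helper added above unwraps `AsyncIterator[T]` annotations (used for SSE endpoints) down to their item type, recursing through `Union`s so the generator can emit a concrete schema. A minimal self-contained sketch of the same idea, avoiding the private `typing._UnionGenericAlias` used in the diff (class names are illustrative):

```python
import collections.abc
import typing
from typing import AsyncIterator, Union


def unwrap_async_iterators(t):
    """Recursively replace AsyncIterator[T] with T in a type annotation."""
    if typing.get_origin(t) is collections.abc.AsyncIterator:
        # keep only the yielded item type
        return typing.get_args(t)[0]
    if typing.get_origin(t) is Union:
        return Union[tuple(unwrap_async_iterators(a) for a in typing.get_args(t))]
    return t


class Turn: ...


class Chunk: ...


# Union[Turn, AsyncIterator[Chunk]] collapses to Union[Turn, Chunk]
print(unwrap_async_iterators(Union[Turn, AsyncIterator[Chunk]]))
```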
Generated OpenAPI spec (JSON):

```diff
@@ -21,7 +21,7 @@
   "info": {
     "title": "[DRAFT] Llama Stack Specification",
     "version": "0.0.1",
-    "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-24 17:40:59.576117"
+    "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-31 14:28:52.128905"
   },
   "servers": [
     {
@@ -320,11 +320,18 @@
       "post": {
         "responses": {
           "200": {
-            "description": "OK",
+            "description": "A single turn in an interaction with an Agentic System. **OR** streamed agent turn completion response.",
             "content": {
               "text/event-stream": {
                 "schema": {
-                  "$ref": "#/components/schemas/AgentTurnResponseStreamChunk"
+                  "oneOf": [
+                    {
+                      "$ref": "#/components/schemas/Turn"
+                    },
+                    {
+                      "$ref": "#/components/schemas/AgentTurnResponseStreamChunk"
+                    }
+                  ]
                 }
               }
             }
@@ -934,7 +941,7 @@
             "schema": {
               "oneOf": [
                 {
-                  "$ref": "#/components/schemas/ScoringFunctionDefWithProvider"
+                  "$ref": "#/components/schemas/ScoringFnDefWithProvider"
                 },
                 {
                   "type": "null"
@@ -1555,7 +1562,7 @@
           "content": {
             "application/jsonl": {
               "schema": {
-                "$ref": "#/components/schemas/ScoringFunctionDefWithProvider"
+                "$ref": "#/components/schemas/ScoringFnDefWithProvider"
               }
             }
           }
@@ -2762,7 +2769,7 @@
           "const": "json_schema",
           "default": "json_schema"
         },
-        "schema": {
+        "json_schema": {
           "type": "object",
           "additionalProperties": {
             "oneOf": [
@@ -2791,7 +2798,7 @@
       "additionalProperties": false,
       "required": [
         "type",
-        "schema"
+        "json_schema"
       ]
     },
     {
@@ -3018,7 +3025,7 @@
           "const": "json_schema",
           "default": "json_schema"
         },
-        "schema": {
+        "json_schema": {
           "type": "object",
           "additionalProperties": {
             "oneOf": [
@@ -3047,7 +3054,7 @@
       "additionalProperties": false,
       "required": [
         "type",
-        "schema"
+        "json_schema"
       ]
     },
     {
@@ -4002,7 +4009,8 @@
       "additionalProperties": false,
       "required": [
         "event"
-      ]
+      ],
+      "title": "streamed agent turn completion response."
     },
     "AgentTurnResponseTurnCompletePayload": {
       "type": "object",
@@ -5004,24 +5012,6 @@
             "type"
           ]
         },
-        {
-          "type": "object",
-          "properties": {
-            "type": {
-              "type": "string",
-              "const": "custom",
-              "default": "custom"
-            },
-            "validator_class": {
-              "type": "string"
-            }
-          },
-          "additionalProperties": false,
-          "required": [
-            "type",
-            "validator_class"
-          ]
-        },
         {
           "type": "object",
           "properties": {
@@ -5304,24 +5294,6 @@
             "type"
           ]
         },
-        {
-          "type": "object",
-          "properties": {
-            "type": {
-              "type": "string",
-              "const": "custom",
-              "default": "custom"
-            },
-            "validator_class": {
-              "type": "string"
-            }
-          },
-          "additionalProperties": false,
-          "required": [
-            "type",
-            "validator_class"
-          ]
-        },
         {
           "type": "object",
           "properties": {
@@ -5376,7 +5348,7 @@
         "type"
       ]
     },
-    "ScoringFunctionDefWithProvider": {
+    "ScoringFnDefWithProvider": {
       "type": "object",
      "properties": {
        "identifier": {
@@ -5516,24 +5488,6 @@
             "type"
           ]
         },
-        {
-          "type": "object",
-          "properties": {
-            "type": {
-              "type": "string",
-              "const": "custom",
-              "default": "custom"
-            },
-            "validator_class": {
-              "type": "string"
-            }
-          },
-          "additionalProperties": false,
-          "required": [
-            "type",
-            "validator_class"
-          ]
-        },
         {
           "type": "object",
           "properties": {
@@ -5586,6 +5540,12 @@
         },
         "prompt_template": {
           "type": "string"
         },
+        "judge_score_regex": {
+          "type": "array",
+          "items": {
+            "type": "string"
+          }
+        }
       },
       "additionalProperties": false,
@@ -6339,10 +6299,10 @@
         "finetuned_model": {
           "$ref": "#/components/schemas/URL"
         },
-        "dataset": {
+        "dataset_id": {
           "type": "string"
         },
-        "validation_dataset": {
+        "validation_dataset_id": {
           "type": "string"
         },
         "algorithm": {
@@ -6412,8 +6372,8 @@
       "required": [
         "job_uuid",
         "finetuned_model",
-        "dataset",
-        "validation_dataset",
+        "dataset_id",
+        "validation_dataset_id",
         "algorithm",
         "algorithm_config",
         "optimizer_config",
@@ -6595,7 +6555,7 @@
       "type": "object",
       "properties": {
         "function_def": {
-          "$ref": "#/components/schemas/ScoringFunctionDefWithProvider"
+          "$ref": "#/components/schemas/ScoringFnDefWithProvider"
         }
       },
       "additionalProperties": false,
@@ -6893,10 +6853,10 @@
         "model": {
           "type": "string"
         },
-        "dataset": {
+        "dataset_id": {
           "type": "string"
         },
-        "validation_dataset": {
+        "validation_dataset_id": {
           "type": "string"
         },
         "algorithm": {
@@ -6976,8 +6936,8 @@
       "required": [
         "job_uuid",
         "model",
-        "dataset",
-        "validation_dataset",
+        "dataset_id",
+        "validation_dataset_id",
         "algorithm",
         "algorithm_config",
         "optimizer_config",
@@ -7102,57 +7062,57 @@
     }
   ],
   "tags": [
-    { "name": "Eval" },
-    { "name": "ScoringFunctions" },
-    { "name": "SyntheticDataGeneration" },
-    { "name": "Inspect" },
-    { "name": "PostTraining" },
-    { "name": "Models" },
-    { "name": "Safety" },
-    { "name": "MemoryBanks" },
-    { "name": "DatasetIO" },
-    { "name": "Memory" },
-    { "name": "Scoring" },
-    { "name": "Shields" },
-    { "name": "Datasets" },
-    { "name": "Inference" },
-    { "name": "Telemetry" },
+    { "name": "Eval" },
+    { "name": "MemoryBanks" },
+    { "name": "Models" },
+    { "name": "BatchInference" },
+    { "name": "PostTraining" },
+    { "name": "Agents" },
+    { "name": "Shields" },
+    { "name": "Telemetry" },
+    { "name": "Inspect" },
+    { "name": "DatasetIO" },
+    { "name": "SyntheticDataGeneration" },
+    { "name": "Datasets" },
+    { "name": "Scoring" },
+    { "name": "ScoringFunctions" },
+    { "name": "Safety" },
     {
       "name": "BuiltinTool",
       "description": "<SchemaDefinition schemaRef=\"#/components/schemas/BuiltinTool\" />"
@@ -7355,7 +7315,7 @@
     },
     {
       "name": "AgentTurnResponseStreamChunk",
-      "description": "<SchemaDefinition schemaRef=\"#/components/schemas/AgentTurnResponseStreamChunk\" />"
+      "description": "streamed agent turn completion response.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/AgentTurnResponseStreamChunk\" />"
     },
     {
       "name": "AgentTurnResponseTurnCompletePayload",
@@ -7486,8 +7446,8 @@
       "description": "<SchemaDefinition schemaRef=\"#/components/schemas/Parameter\" />"
     },
     {
-      "name": "ScoringFunctionDefWithProvider",
-      "description": "<SchemaDefinition schemaRef=\"#/components/schemas/ScoringFunctionDefWithProvider\" />"
+      "name": "ScoringFnDefWithProvider",
+      "description": "<SchemaDefinition schemaRef=\"#/components/schemas/ScoringFnDefWithProvider\" />"
     },
     {
       "name": "ShieldDefWithProvider",
@@ -7805,7 +7765,7 @@
       "ScoreBatchResponse",
       "ScoreRequest",
       "ScoreResponse",
-      "ScoringFunctionDefWithProvider",
+      "ScoringFnDefWithProvider",
       "ScoringResult",
       "SearchToolDefinition",
       "Session",
```
Generated OpenAPI spec (YAML):

```diff
@@ -190,6 +190,7 @@ components:
           $ref: '#/components/schemas/AgentTurnResponseEvent'
       required:
       - event
+      title: streamed agent turn completion response.
       type: object
     AgentTurnResponseTurnCompletePayload:
       additionalProperties: false
@@ -360,7 +361,7 @@ components:
       oneOf:
       - additionalProperties: false
         properties:
-          schema:
+          json_schema:
             additionalProperties:
               oneOf:
               - type: 'null'
@@ -376,7 +377,7 @@ components:
             type: string
         required:
         - type
-        - schema
+        - json_schema
         type: object
       - additionalProperties: false
         properties:
@@ -541,7 +542,7 @@ components:
       oneOf:
       - additionalProperties: false
         properties:
-          schema:
+          json_schema:
             additionalProperties:
               oneOf:
               - type: 'null'
@@ -557,7 +558,7 @@ components:
             type: string
         required:
         - type
-        - schema
+        - json_schema
         type: object
       - additionalProperties: false
         properties:
@@ -747,18 +748,6 @@ components:
         required:
         - type
         type: object
-      - additionalProperties: false
-        properties:
-          type:
-            const: custom
-            default: custom
-            type: string
-          validator_class:
-            type: string
-        required:
-        - type
-        - validator_class
-        type: object
       - additionalProperties: false
         properties:
           type:
@@ -1575,18 +1564,6 @@ components:
         required:
         - type
         type: object
-      - additionalProperties: false
-        properties:
-          type:
-            const: custom
-            default: custom
-            type: string
-          validator_class:
-            type: string
-        required:
-        - type
-        - validator_class
-        type: object
       - additionalProperties: false
         properties:
           type:
@@ -1724,7 +1701,7 @@ components:
           $ref: '#/components/schemas/RLHFAlgorithm'
         algorithm_config:
           $ref: '#/components/schemas/DPOAlignmentConfig'
-        dataset:
+        dataset_id:
           type: string
         finetuned_model:
           $ref: '#/components/schemas/URL'
@@ -1754,13 +1731,13 @@ components:
           $ref: '#/components/schemas/OptimizerConfig'
         training_config:
           $ref: '#/components/schemas/TrainingConfig'
-        validation_dataset:
+        validation_dataset_id:
           type: string
       required:
       - job_uuid
       - finetuned_model
-      - dataset
-      - validation_dataset
+      - dataset_id
+      - validation_dataset_id
       - algorithm
       - algorithm_config
       - optimizer_config
@@ -1899,7 +1876,7 @@ components:
       additionalProperties: false
      properties:
        function_def:
-          $ref: '#/components/schemas/ScoringFunctionDefWithProvider'
+          $ref: '#/components/schemas/ScoringFnDefWithProvider'
       required:
       - function_def
       type: object
@@ -2121,7 +2098,7 @@ components:
       required:
       - results
       type: object
-    ScoringFunctionDefWithProvider:
+    ScoringFnDefWithProvider:
       additionalProperties: false
       properties:
         context:
@@ -2129,6 +2106,10 @@ components:
           properties:
             judge_model:
               type: string
+            judge_score_regex:
+              items:
+                type: string
+              type: array
             prompt_template:
               type: string
           required:
@@ -2219,18 +2200,6 @@ components:
         required:
         - type
         type: object
-      - additionalProperties: false
-        properties:
-          type:
-            const: custom
-            default: custom
-            type: string
-          validator_class:
-            type: string
-        required:
-        - type
-        - validator_class
-        type: object
       - additionalProperties: false
         properties:
           type:
@@ -2484,7 +2453,7 @@ components:
          - $ref: '#/components/schemas/LoraFinetuningConfig'
          - $ref: '#/components/schemas/QLoraFinetuningConfig'
          - $ref: '#/components/schemas/DoraFinetuningConfig'
-        dataset:
+        dataset_id:
           type: string
         hyperparam_search_config:
           additionalProperties:
@@ -2514,13 +2483,13 @@ components:
           $ref: '#/components/schemas/OptimizerConfig'
         training_config:
           $ref: '#/components/schemas/TrainingConfig'
-        validation_dataset:
+        validation_dataset_id:
           type: string
       required:
       - job_uuid
       - model
-      - dataset
-      - validation_dataset
+      - dataset_id
+      - validation_dataset_id
       - algorithm
       - algorithm_config
       - optimizer_config
@@ -3029,7 +2998,7 @@ info:
   description: "This is the specification of the llama stack that provides\n \
     \ a set of endpoints and their corresponding interfaces that are tailored\
     \ to\n best leverage Llama Models. The specification is still in\
-    \ draft and subject to change.\n Generated at 2024-10-24 17:40:59.576117"
+    \ draft and subject to change.\n Generated at 2024-10-31 14:28:52.128905"
   title: '[DRAFT] Llama Stack Specification'
   version: 0.0.1
 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -3222,8 +3191,11 @@ paths:
       content:
         text/event-stream:
           schema:
-            $ref: '#/components/schemas/AgentTurnResponseStreamChunk'
-      description: OK
+            oneOf:
+            - $ref: '#/components/schemas/Turn'
+            - $ref: '#/components/schemas/AgentTurnResponseStreamChunk'
+      description: A single turn in an interaction with an Agentic System. **OR**
+        streamed agent turn completion response.
   tags:
   - Agents
 /agents/turn/get:
@@ -4122,7 +4094,7 @@ paths:
         application/json:
           schema:
             oneOf:
-            - $ref: '#/components/schemas/ScoringFunctionDefWithProvider'
+            - $ref: '#/components/schemas/ScoringFnDefWithProvider'
             - type: 'null'
       description: OK
   tags:
@@ -4142,7 +4114,7 @@ paths:
       content:
         application/jsonl:
           schema:
-            $ref: '#/components/schemas/ScoringFunctionDefWithProvider'
+            $ref: '#/components/schemas/ScoringFnDefWithProvider'
       description: OK
   tags:
   - ScoringFunctions
@@ -4308,23 +4280,23 @@ security:
 servers:
 - url: http://any-hosted-llama-stack.com
 tags:
-- name: Eval
-- name: ScoringFunctions
-- name: SyntheticDataGeneration
-- name: Inspect
-- name: PostTraining
-- name: Models
-- name: Safety
-- name: MemoryBanks
-- name: DatasetIO
-- name: Memory
-- name: Scoring
-- name: Shields
-- name: Datasets
-- name: Inference
-- name: Telemetry
+- name: Eval
+- name: MemoryBanks
+- name: Models
+- name: BatchInference
+- name: PostTraining
+- name: Agents
+- name: Shields
+- name: Telemetry
+- name: Inspect
+- name: DatasetIO
+- name: SyntheticDataGeneration
+- name: Datasets
+- name: Scoring
+- name: ScoringFunctions
+- name: Safety
 - description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
   name: BuiltinTool
 - description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
@@ -4483,8 +4455,11 @@ tags:
 - description: <SchemaDefinition schemaRef="#/components/schemas/AgentTurnResponseStepStartPayload"
   />
   name: AgentTurnResponseStepStartPayload
-- description: <SchemaDefinition schemaRef="#/components/schemas/AgentTurnResponseStreamChunk"
-  />
+- description: 'streamed agent turn completion response.
+
+
+  <SchemaDefinition schemaRef="#/components/schemas/AgentTurnResponseStreamChunk"
+  />'
   name: AgentTurnResponseStreamChunk
 - description: <SchemaDefinition schemaRef="#/components/schemas/AgentTurnResponseTurnCompletePayload"
   />
@@ -4577,9 +4552,9 @@ tags:
   name: PaginatedRowsResult
 - description: <SchemaDefinition schemaRef="#/components/schemas/Parameter" />
   name: Parameter
-- description: <SchemaDefinition schemaRef="#/components/schemas/ScoringFunctionDefWithProvider"
+- description: <SchemaDefinition schemaRef="#/components/schemas/ScoringFnDefWithProvider"
   />
-  name: ScoringFunctionDefWithProvider
+  name: ScoringFnDefWithProvider
 - description: <SchemaDefinition schemaRef="#/components/schemas/ShieldDefWithProvider"
   />
   name: ShieldDefWithProvider
@@ -4844,7 +4819,7 @@ x-tagGroups:
   - ScoreBatchResponse
   - ScoreRequest
   - ScoreResponse
-  - ScoringFunctionDefWithProvider
+  - ScoringFnDefWithProvider
   - ScoringResult
   - SearchToolDefinition
   - Session
```
Agents API (Python):

```diff
@@ -8,6 +8,7 @@ from datetime import datetime
 from enum import Enum
 from typing import (
     Any,
+    AsyncIterator,
     Dict,
     List,
     Literal,
@@ -405,6 +406,8 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
 
 @json_schema_type
 class AgentTurnResponseStreamChunk(BaseModel):
+    """streamed agent turn completion response."""
+
     event: AgentTurnResponseEvent
 
 
@@ -434,7 +437,7 @@ class Agents(Protocol):
         ],
         attachments: Optional[List[Attachment]] = None,
         stream: Optional[bool] = False,
-    ) -> AgentTurnResponseStreamChunk: ...
+    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
 
     @webmethod(route="/agents/turn/get")
     async def get_agents_turn(
```
Inference API (Python):

```diff
@@ -6,7 +6,15 @@
 
 from enum import Enum
 
-from typing import List, Literal, Optional, Protocol, runtime_checkable, Union
+from typing import (
+    AsyncIterator,
+    List,
+    Literal,
+    Optional,
+    Protocol,
+    runtime_checkable,
+    Union,
+)
 
 from llama_models.schema_utils import json_schema_type, webmethod
 
@@ -224,7 +232,7 @@ class Inference(Protocol):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
+    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...
 
     @webmethod(route="/inference/chat_completion")
     async def chat_completion(
@@ -239,7 +247,9 @@ class Inference(Protocol):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
+    ) -> Union[
+        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
+    ]: ...
 
     @webmethod(route="/inference/embeddings")
     async def embeddings(
```
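With the changed signatures above, `chat_completion` advertises either a complete response or an async iterator of stream chunks, depending on `stream`. A minimal sketch of the resulting calling convention, assuming `inference` is an object implementing the Inference protocol and `messages` is an existing list of Message objects (the model id is illustrative):

```python
# A hedged sketch only, not the library's own client code.
async def run_chat(inference, messages, stream: bool = True):
    response = await inference.chat_completion(
        model="Llama3.1-8B-Instruct",  # illustrative model id
        messages=messages,
        stream=stream,
    )
    if stream:
        # streaming: an AsyncIterator[ChatCompletionResponseStreamChunk]
        async for chunk in response:
            print(chunk.event.delta, end="")
    else:
        # non-streaming: a single ChatCompletionResponse
        print(response.completion_message.content)
```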
Container build script (shell):

```diff
@@ -77,9 +77,9 @@ if [ -n "$LLAMA_STACK_DIR" ]; then
   # Install in editable format. We will mount the source code into the container
   # so that changes will be reflected in the container without having to do a
   # rebuild. This is just for development convenience.
-  add_to_docker "RUN pip install -e $stack_mount"
+  add_to_docker "RUN pip install --no-cache -e $stack_mount"
 else
-  add_to_docker "RUN pip install llama-stack"
+  add_to_docker "RUN pip install --no-cache llama-stack"
 fi
 
 if [ -n "$LLAMA_MODELS_DIR" ]; then
@@ -90,19 +90,19 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then
 
   add_to_docker <<EOF
 RUN pip uninstall -y llama-models
-RUN pip install $models_mount
+RUN pip install --no-cache $models_mount
 
 EOF
 fi
 
 if [ -n "$pip_dependencies" ]; then
-  add_to_docker "RUN pip install $pip_dependencies"
+  add_to_docker "RUN pip install --no-cache $pip_dependencies"
 fi
 
 if [ -n "$special_pip_deps" ]; then
   IFS='#' read -ra parts <<<"$special_pip_deps"
   for part in "${parts[@]}"; do
-    add_to_docker "RUN pip install $part"
+    add_to_docker "RUN pip install --no-cache $part"
   done
 fi
```
llama_stack/distribution/client.py (new file, 221 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import inspect

import json
from collections.abc import AsyncIterator
from enum import Enum
from typing import Any, get_args, get_origin, Type, Union

import httpx
from pydantic import BaseModel, parse_obj_as
from termcolor import cprint

from llama_stack.providers.datatypes import RemoteProviderConfig

_CLIENT_CLASSES = {}


async def get_client_impl(
    protocol, additional_protocol, config: RemoteProviderConfig, _deps: Any
):
    client_class = create_api_client_class(protocol, additional_protocol)
    impl = client_class(config.url)
    await impl.initialize()
    return impl


def create_api_client_class(protocol, additional_protocol) -> Type:
    if protocol in _CLIENT_CLASSES:
        return _CLIENT_CLASSES[protocol]

    protocols = [protocol, additional_protocol] if additional_protocol else [protocol]

    class APIClient:
        def __init__(self, base_url: str):
            print(f"({protocol.__name__}) Connecting to {base_url}")
            self.base_url = base_url.rstrip("/")
            self.routes = {}

            # Store routes for this protocol
            for p in protocols:
                for name, method in inspect.getmembers(p):
                    if hasattr(method, "__webmethod__"):
                        sig = inspect.signature(method)
                        self.routes[name] = (method.__webmethod__, sig)

        async def initialize(self):
            pass

        async def shutdown(self):
            pass

        async def __acall__(self, method_name: str, *args, **kwargs) -> Any:
            assert method_name in self.routes, f"Unknown endpoint: {method_name}"

            # TODO: make this more precise, same thing needs to happen in server.py
            is_streaming = kwargs.get("stream", False)
            if is_streaming:
                return self._call_streaming(method_name, *args, **kwargs)
            else:
                return await self._call_non_streaming(method_name, *args, **kwargs)

        async def _call_non_streaming(self, method_name: str, *args, **kwargs) -> Any:
            _, sig = self.routes[method_name]

            if sig.return_annotation is None:
                return_type = None
            else:
                return_type = extract_non_async_iterator_type(sig.return_annotation)
                assert (
                    return_type
                ), f"Could not extract return type for {sig.return_annotation}"

            async with httpx.AsyncClient() as client:
                params = self.httpx_request_params(method_name, *args, **kwargs)
                response = await client.request(**params)
                response.raise_for_status()

                j = response.json()
                if j is None:
                    return None
                return parse_obj_as(return_type, j)

        async def _call_streaming(self, method_name: str, *args, **kwargs) -> Any:
            webmethod, sig = self.routes[method_name]

            return_type = extract_async_iterator_type(sig.return_annotation)
            assert (
                return_type
            ), f"Could not extract return type for {sig.return_annotation}"

            async with httpx.AsyncClient() as client:
                params = self.httpx_request_params(method_name, *args, **kwargs)
                async with client.stream(**params) as response:
                    response.raise_for_status()

                    async for line in response.aiter_lines():
                        if line.startswith("data:"):
                            data = line[len("data: ") :]
                            try:
                                if "error" in data:
                                    cprint(data, "red")
                                    continue

                                yield parse_obj_as(return_type, json.loads(data))
                            except Exception as e:
                                print(data)
                                print(f"Error with parsing or validation: {e}")

        def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
            webmethod, sig = self.routes[method_name]

            parameters = list(sig.parameters.values())[1:]  # skip `self`
            for i, param in enumerate(parameters):
                if i >= len(args):
                    break
                kwargs[param.name] = args[i]

            url = f"{self.base_url}{webmethod.route}"

            def convert(value):
                if isinstance(value, list):
                    return [convert(v) for v in value]
                elif isinstance(value, dict):
                    return {k: convert(v) for k, v in value.items()}
                elif isinstance(value, BaseModel):
                    return json.loads(value.model_dump_json())
                elif isinstance(value, Enum):
                    return value.value
                else:
                    return value

            params = {}
            data = {}
            if webmethod.method == "GET":
                params.update(kwargs)
            else:
                data.update(convert(kwargs))

            return dict(
                method=webmethod.method or "POST",
                url=url,
                headers={"Content-Type": "application/json"},
                params=params,
                json=data,
                timeout=30,
            )

    # Add protocol methods to the wrapper
    for p in protocols:
        for name, method in inspect.getmembers(p):
            if hasattr(method, "__webmethod__"):

                async def method_impl(self, *args, method_name=name, **kwargs):
                    return await self.__acall__(method_name, *args, **kwargs)

                method_impl.__name__ = name
                method_impl.__qualname__ = f"APIClient.{name}"
                method_impl.__signature__ = inspect.signature(method)
                setattr(APIClient, name, method_impl)

    # Name the class after the protocol
    APIClient.__name__ = f"{protocol.__name__}Client"
    _CLIENT_CLASSES[protocol] = APIClient
    return APIClient


# not quite general these methods are
def extract_non_async_iterator_type(type_hint):
    if get_origin(type_hint) is Union:
        args = get_args(type_hint)
        for arg in args:
            if not issubclass(get_origin(arg) or arg, AsyncIterator):
                return arg
    return type_hint


def extract_async_iterator_type(type_hint):
    if get_origin(type_hint) is Union:
        args = get_args(type_hint)
        for arg in args:
            if issubclass(get_origin(arg) or arg, AsyncIterator):
                inner_args = get_args(arg)
                return inner_args[0]
    return None


async def example(model: str = None):
    from llama_stack.apis.inference import Inference, UserMessage  # noqa: F403
    from llama_stack.apis.inference.event_logger import EventLogger

    client_class = create_api_client_class(Inference)
    client = client_class("http://localhost:5003")

    if not model:
        model = "Llama3.2-3B-Instruct"

    message = UserMessage(
        content="hello world, write me a 2 sentence poem about the moon"
    )
    cprint(f"User>{message.content}", "green")

    stream = True
    iterator = await client.chat_completion(
        model=model,
        messages=[message],
        stream=stream,
    )

    async for log in EventLogger().log(iterator):
        log.print()


if __name__ == "__main__":
    import asyncio

    asyncio.run(example())
```
Provider resolver (Python):

```diff
@@ -40,19 +40,21 @@ def api_protocol_map() -> Dict[Api, Any]:
         Api.safety: Safety,
         Api.shields: Shields,
         Api.telemetry: Telemetry,
-        Api.datasets: Datasets,
         Api.datasetio: DatasetIO,
-        Api.scoring_functions: ScoringFunctions,
+        Api.datasets: Datasets,
         Api.scoring: Scoring,
+        Api.scoring_functions: ScoringFunctions,
         Api.eval: Eval,
     }
 
 
 def additional_protocols_map() -> Dict[Api, Any]:
     return {
-        Api.inference: ModelsProtocolPrivate,
-        Api.memory: MemoryBanksProtocolPrivate,
-        Api.safety: ShieldsProtocolPrivate,
+        Api.inference: (ModelsProtocolPrivate, Models),
+        Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks),
+        Api.safety: (ShieldsProtocolPrivate, Shields),
+        Api.datasetio: (DatasetsProtocolPrivate, Datasets),
+        Api.scoring: (ScoringFunctionsProtocolPrivate, ScoringFunctions),
     }
 
 
@@ -112,8 +114,6 @@ async def resolve_impls(
         if info.router_api.value not in apis_to_serve:
             continue
 
-        available_providers = providers_with_specs[f"inner-{info.router_api.value}"]
-
         providers_with_specs[info.routing_table_api.value] = {
             "__builtin__": ProviderWithSpec(
                 provider_id="__routing_table__",
@@ -246,14 +246,21 @@ async def instantiate_provider(
 
     args = []
     if isinstance(provider_spec, RemoteProviderSpec):
-        if provider_spec.adapter:
-            method = "get_adapter_impl"
-        else:
-            method = "get_client_impl"
-
         config_type = instantiate_class_type(provider_spec.config_class)
         config = config_type(**provider.config)
-        args = [config, deps]
+
+        if provider_spec.adapter:
+            method = "get_adapter_impl"
+            args = [config, deps]
+        else:
+            method = "get_client_impl"
+            protocol = protocols[provider_spec.api]
+            if provider_spec.api in additional_protocols:
+                _, additional_protocol = additional_protocols[provider_spec.api]
+            else:
+                additional_protocol = None
+            args = [protocol, additional_protocol, config, deps]
 
     elif isinstance(provider_spec, AutoRoutedProviderSpec):
         method = "get_auto_router_impl"
@@ -282,7 +289,7 @@ async def instantiate_provider(
         not isinstance(provider_spec, AutoRoutedProviderSpec)
         and provider_spec.api in additional_protocols
     ):
-        additional_api = additional_protocols[provider_spec.api]
+        additional_api, _ = additional_protocols[provider_spec.api]
         check_protocol_compliance(impl, additional_api)
 
     return impl
```
Routing tables (Python):

```diff
@@ -22,6 +22,13 @@ def get_impl_api(p: Any) -> Api:
 
 async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
     api = get_impl_api(p)
+
+    if obj.provider_id == "remote":
+        # if this is just a passthrough, we want to let the remote
+        # end actually do the registration with the correct provider
+        obj = obj.model_copy(deep=True)
+        obj.provider_id = ""
+
     if api == Api.inference:
         await p.register_model(obj)
     elif api == Api.safety:
@@ -51,11 +58,22 @@ class CommonRoutingTableImpl(RoutingTable):
     async def initialize(self) -> None:
         self.registry: Registry = {}
 
-        def add_objects(objs: List[RoutableObjectWithProvider]) -> None:
+        def add_objects(
+            objs: List[RoutableObjectWithProvider], provider_id: str, cls
+        ) -> None:
             for obj in objs:
                 if obj.identifier not in self.registry:
                     self.registry[obj.identifier] = []
 
+                if cls is None:
+                    obj.provider_id = provider_id
+                else:
+                    if provider_id == "remote":
+                        # if this is just a passthrough, we got the *WithProvider object
+                        # so we should just override the provider in-place
+                        obj.provider_id = provider_id
+                    else:
+                        obj = cls(**obj.model_dump(), provider_id=provider_id)
                 self.registry[obj.identifier].append(obj)
 
         for pid, p in self.impls_by_provider_id.items():
@@ -63,47 +81,27 @@ class CommonRoutingTableImpl(RoutingTable):
             if api == Api.inference:
                 p.model_store = self
                 models = await p.list_models()
-                add_objects(
-                    [ModelDefWithProvider(**m.dict(), provider_id=pid) for m in models]
-                )
+                add_objects(models, pid, ModelDefWithProvider)
 
             elif api == Api.safety:
                 p.shield_store = self
                 shields = await p.list_shields()
-                add_objects(
-                    [
-                        ShieldDefWithProvider(**s.dict(), provider_id=pid)
-                        for s in shields
-                    ]
-                )
+                add_objects(shields, pid, ShieldDefWithProvider)
 
             elif api == Api.memory:
                 p.memory_bank_store = self
                 memory_banks = await p.list_memory_banks()
-
-                # do in-memory updates due to pesky Annotated unions
-                for m in memory_banks:
-                    m.provider_id = pid
-
-                add_objects(memory_banks)
+                add_objects(memory_banks, pid, None)
 
             elif api == Api.datasetio:
                 p.dataset_store = self
                 datasets = await p.list_datasets()
-
-                # do in-memory updates due to pesky Annotated unions
-                for d in datasets:
-                    d.provider_id = pid
+                add_objects(datasets, pid, DatasetDefWithProvider)
 
             elif api == Api.scoring:
                 p.scoring_function_store = self
                 scoring_functions = await p.list_scoring_functions()
-                add_objects(
-                    [
-                        ScoringFnDefWithProvider(**s.dict(), provider_id=pid)
-                        for s in scoring_functions
-                    ]
-                )
+                add_objects(scoring_functions, pid, ScoringFnDefWithProvider)
 
     async def shutdown(self) -> None:
         for p in self.impls_by_provider_id.values():
```
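The reworked `add_objects` centralizes the `*WithProvider` wrapping that each API previously did inline. A simplified sketch of that wrapping logic with illustrative pydantic models (not the actual llama-stack classes); the `cls is None` branch mirrors the memory-bank case, where the object already carries a `provider_id` field:

```python
from typing import List, Optional, Type

from pydantic import BaseModel


class ModelDef(BaseModel):
    identifier: str


class ModelDefWithProvider(ModelDef):
    provider_id: str


def wrap_with_provider(
    objs: List[BaseModel], provider_id: str, cls: Optional[Type[BaseModel]]
) -> List[BaseModel]:
    wrapped = []
    for obj in objs:
        if cls is None:
            # object already has a provider_id field: update it in place
            obj.provider_id = provider_id
            wrapped.append(obj)
        else:
            # re-create the object as its *WithProvider variant
            wrapped.append(cls(**obj.model_dump(), provider_id=provider_id))
    return wrapped


print(wrap_with_provider([ModelDef(identifier="m1")], "provider-0", ModelDefWithProvider))
```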
Bedrock inference adapter (Python):

```diff
@@ -55,7 +55,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+    ) -> AsyncGenerator:
         raise NotImplementedError()
 
     @staticmethod
@@ -290,23 +290,130 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> (
-        AsyncGenerator
-    ):  # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
-        bedrock_model = self.map_to_provider_model(model)
-        inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
-            sampling_params
+    ) -> Union[
+        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
+    ]:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
+            stream=stream,
+            logprobs=logprobs,
         )
 
-        tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
+        if stream:
+            return self._stream_chat_completion(request)
+        else:
+            return await self._nonstream_chat_completion(request)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
+        params = self._get_params_for_chat_completion(request)
+        converse_api_res = self.client.converse(**params)
+
+        output_message = BedrockInferenceAdapter._bedrock_message_to_message(
+            converse_api_res
+        )
+
+        return ChatCompletionResponse(
+            completion_message=output_message,
+            logprobs=None,
+        )
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        params = self._get_params_for_chat_completion(request)
+        converse_stream_api_res = self.client.converse_stream(**params)
+        event_stream = converse_stream_api_res["stream"]
+
+        for chunk in event_stream:
+            if "messageStart" in chunk:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.start,
+                        delta="",
+                    )
+                )
+            elif "contentBlockStart" in chunk:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=ToolCallDelta(
+                            content=ToolCall(
+                                tool_name=chunk["contentBlockStart"]["toolUse"]["name"],
+                                call_id=chunk["contentBlockStart"]["toolUse"][
+                                    "toolUseId"
+                                ],
+                            ),
+                            parse_status=ToolCallParseStatus.started,
+                        ),
+                    )
+                )
+            elif "contentBlockDelta" in chunk:
+                if "text" in chunk["contentBlockDelta"]["delta"]:
+                    delta = chunk["contentBlockDelta"]["delta"]["text"]
+                else:
+                    delta = ToolCallDelta(
+                        content=ToolCall(
+                            arguments=chunk["contentBlockDelta"]["delta"]["toolUse"][
+                                "input"
+                            ]
+                        ),
+                        parse_status=ToolCallParseStatus.success,
+                    )
+
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=delta,
+                    )
+                )
+            elif "contentBlockStop" in chunk:
+                # Ignored
+                pass
+            elif "messageStop" in chunk:
+                stop_reason = (
+                    BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
+                        chunk["messageStop"]["stopReason"]
+                    )
+                )
+
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.complete,
+                        delta="",
+                        stop_reason=stop_reason,
+                    )
+                )
+            elif "metadata" in chunk:
+                # Ignored
+                pass
+            else:
+                # Ignored
+                pass
+
+    def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
+        bedrock_model = self.map_to_provider_model(request.model)
+        inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
+            request.sampling_params
+        )
+
+        tool_config = BedrockInferenceAdapter._tools_to_tool_config(
+            request.tools, request.tool_choice
+        )
         bedrock_messages, system_bedrock_messages = (
-            BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
+            BedrockInferenceAdapter._messages_to_bedrock_messages(request.messages)
         )
 
         converse_api_params = {
@@ -317,93 +424,12 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         converse_api_params["inferenceConfig"] = inference_config
 
         # Tool use is not supported in streaming mode
-        if tool_config and not stream:
+        if tool_config and not request.stream:
             converse_api_params["toolConfig"] = tool_config
         if system_bedrock_messages:
             converse_api_params["system"] = system_bedrock_messages
 
-        if not stream:
-            converse_api_res = self.client.converse(**converse_api_params)
-
-            output_message = BedrockInferenceAdapter._bedrock_message_to_message(
-                converse_api_res
-            )
-
-            yield ChatCompletionResponse(
-                completion_message=output_message,
-                logprobs=None,
-            )
-        else:
-            converse_stream_api_res = self.client.converse_stream(**converse_api_params)
-            event_stream = converse_stream_api_res["stream"]
-
-            for chunk in event_stream:
-                if "messageStart" in chunk:
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.start,
-                            delta="",
-                        )
-                    )
-                elif "contentBlockStart" in chunk:
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content=ToolCall(
-                                    tool_name=chunk["contentBlockStart"]["toolUse"][
-                                        "name"
-                                    ],
-                                    call_id=chunk["contentBlockStart"]["toolUse"][
-                                        "toolUseId"
-                                    ],
-                                ),
-                                parse_status=ToolCallParseStatus.started,
-                            ),
-                        )
-                    )
-                elif "contentBlockDelta" in chunk:
-                    if "text" in chunk["contentBlockDelta"]["delta"]:
-                        delta = chunk["contentBlockDelta"]["delta"]["text"]
-                    else:
-                        delta = ToolCallDelta(
-                            content=ToolCall(
-                                arguments=chunk["contentBlockDelta"]["delta"][
-                                    "toolUse"
-                                ]["input"]
-                            ),
-                            parse_status=ToolCallParseStatus.success,
-                        )
-
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=delta,
-                        )
-                    )
-                elif "contentBlockStop" in chunk:
-                    # Ignored
-                    pass
-                elif "messageStop" in chunk:
-                    stop_reason = (
-                        BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
-                            chunk["messageStop"]["stopReason"]
-                        )
-                    )
-
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.complete,
-                            delta="",
-                            stop_reason=stop_reason,
-                        )
-                    )
-                elif "metadata" in chunk:
-                    # Ignored
-                    pass
-                else:
-                    # Ignored
-                    pass
+        return converse_api_params
 
     async def embeddings(
         self,
```
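One subtlety in the refactor above: `_stream_chat_completion` is an async generator, so `chat_completion` returns it without `await` (calling an async generator function only creates the iterator), while the non-streaming helper is a plain coroutine and must be awaited. A minimal sketch of this dispatch pattern, independent of Bedrock:

```python
import asyncio
from typing import AsyncIterator, Union


async def _nonstream() -> str:
    return "full response"


async def _stream() -> AsyncIterator[str]:
    for part in ("chunk-1 ", "chunk-2"):
        yield part


async def complete(stream: bool) -> Union[str, AsyncIterator[str]]:
    if stream:
        # async generator function: calling it returns the iterator, no await
        return _stream()
    # plain coroutine: await it to get the value
    return await _nonstream()


async def main():
    print(await complete(stream=False))
    async for chunk in await complete(stream=True):
        print(chunk, end="")


asyncio.run(main())
```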
vLLM inference adapter (Python):

```diff
@@ -75,7 +75,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             for model in self.client.models.list()
         ]
 
-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -86,7 +86,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
         raise NotImplementedError()
 
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -111,7 +111,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         if stream:
             return self._stream_chat_completion(request, self.client)
         else:
-            return self._nonstream_chat_completion(request, self.client)
+            return await self._nonstream_chat_completion(request, self.client)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: OpenAI
```
Provider datatypes (Python):

```diff
@@ -60,7 +60,7 @@ class MemoryBanksProtocolPrivate(Protocol):
 class DatasetsProtocolPrivate(Protocol):
     async def list_datasets(self) -> List[DatasetDef]: ...
 
-    async def register_datasets(self, dataset_def: DatasetDef) -> None: ...
+    async def register_dataset(self, dataset_def: DatasetDef) -> None: ...
 
 
 class ScoringFunctionsProtocolPrivate(Protocol):
```
RemoteProviderSpec (Python):

```diff
@@ -171,7 +171,7 @@ as being "Llama Stack compatible"
     def module(self) -> str:
         if self.adapter:
             return self.adapter.module
-        return f"llama_stack.apis.{self.api.value}.client"
+        return "llama_stack.distribution.client"
 
     @property
     def pip_packages(self) -> List[str]:
```
iOS SDK (Swift):

```diff
@@ -81,7 +81,9 @@ func encodeMessage(message: Components.Schemas.ChatCompletionRequest.messagesPay
             switch (m.content) {
             case .case1(let c):
                 prompt += _processContent(c)
-            case .case2(let c):
+            case .ImageMedia(let c):
                 prompt += _processContent(c)
+            case .case3(let c):
+                prompt += _processContent(c)
             }
         case .CompletionMessage(let m):
```
Agents provider tests (Python):

```diff
@@ -26,6 +26,7 @@ from dotenv import load_dotenv
 #
 # ```bash
 # PROVIDER_ID=<your_provider> \
+#   MODEL_ID=<your_model> \
 #   PROVIDER_CONFIG=provider_config.yaml \
 #   pytest -s llama_stack/providers/tests/agents/test_agents.py \
 #   --tb=short --disable-warnings
@@ -44,7 +45,7 @@ async def agents_settings():
         "impl": impls[Api.agents],
         "memory_impl": impls[Api.memory],
         "common_params": {
-            "model": "Llama3.1-8B-Instruct",
+            "model": os.environ["MODEL_ID"] or "Llama3.1-8B-Instruct",
             "instructions": "You are a helpful assistant.",
         },
     }
```
Memory provider tests (Python):

```diff
@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import os
 
 import pytest
 import pytest_asyncio
@@ -73,7 +72,6 @@ async def register_memory_bank(banks_impl: MemoryBanks):
         embedding_model="all-MiniLM-L6-v2",
         chunk_size_in_tokens=512,
         overlap_size_in_tokens=64,
-        provider_id=os.environ["PROVIDER_ID"],
     )
 
     await banks_impl.register_memory_bank(bank)
```