Merge branch 'main' into docs_improvement

Kai Wu 2024-11-04 12:40:18 -08:00
commit 1422d631a8
15 changed files with 554 additions and 340 deletions

View file

@@ -315,7 +315,20 @@ def get_endpoint_operations(
                 )
             else:
                 event_type = None
-            response_type = return_type
+
+            def process_type(t):
+                if typing.get_origin(t) is collections.abc.AsyncIterator:
+                    # NOTE(ashwin): this is SSE and there is no way to represent it. either we make it a List
+                    # or the item type. I am choosing it to be the latter
+                    args = typing.get_args(t)
+                    return args[0]
+                elif typing.get_origin(t) is typing.Union:
+                    types = [process_type(a) for a in typing.get_args(t)]
+                    return typing._UnionGenericAlias(typing.Union, tuple(types))
+                else:
+                    return t
+
+            response_type = process_type(return_type)
 
             # set HTTP request method based on type of request and presence of payload
             if not request_params:
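For context, here is a minimal, runnable sketch of the transformation this hunk performs on a streaming endpoint's return annotation. The two classes are empty stand-ins, and the public `Union[...]` constructor is used in place of the generator's `typing._UnionGenericAlias`:

import collections.abc
import typing
from typing import AsyncIterator, Union

class Turn: ...                          # stand-in for the real Turn model
class AgentTurnResponseStreamChunk: ...  # stand-in for the real chunk model

def process_type(t):
    # AsyncIterator[X] models an SSE stream; collapse it to its item type X,
    # since OpenAPI has no first-class representation for SSE.
    if typing.get_origin(t) is collections.abc.AsyncIterator:
        return typing.get_args(t)[0]
    # Recurse into unions so each member is collapsed independently.
    if typing.get_origin(t) is Union:
        return Union[tuple(process_type(a) for a in typing.get_args(t))]
    return t

print(process_type(Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]))
# -> typing.Union[Turn, AgentTurnResponseStreamChunk], which is what renders
#    as the "oneOf" of Turn and AgentTurnResponseStreamChunk in the spec below.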

View file

@@ -21,7 +21,7 @@
     "info": {
         "title": "[DRAFT] Llama Stack Specification",
         "version": "0.0.1",
-        "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-24 17:40:59.576117"
+        "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-31 14:28:52.128905"
     },
     "servers": [
         {
@@ -320,11 +320,18 @@
             "post": {
                 "responses": {
                     "200": {
-                        "description": "OK",
+                        "description": "A single turn in an interaction with an Agentic System. **OR** streamed agent turn completion response.",
                         "content": {
                             "text/event-stream": {
                                 "schema": {
-                                    "$ref": "#/components/schemas/AgentTurnResponseStreamChunk"
+                                    "oneOf": [
+                                        {
+                                            "$ref": "#/components/schemas/Turn"
+                                        },
+                                        {
+                                            "$ref": "#/components/schemas/AgentTurnResponseStreamChunk"
+                                        }
+                                    ]
                                 }
                             }
                         }
@@ -934,7 +941,7 @@
                         "schema": {
                             "oneOf": [
                                 {
-                                    "$ref": "#/components/schemas/ScoringFunctionDefWithProvider"
+                                    "$ref": "#/components/schemas/ScoringFnDefWithProvider"
                                 },
                                 {
                                     "type": "null"
@@ -1555,7 +1562,7 @@
                 "content": {
                     "application/jsonl": {
                         "schema": {
-                            "$ref": "#/components/schemas/ScoringFunctionDefWithProvider"
+                            "$ref": "#/components/schemas/ScoringFnDefWithProvider"
                         }
                     }
                 }
@@ -2762,7 +2769,7 @@
                     "const": "json_schema",
                     "default": "json_schema"
                 },
-                "schema": {
+                "json_schema": {
                     "type": "object",
                     "additionalProperties": {
                         "oneOf": [
@@ -2791,7 +2798,7 @@
             "additionalProperties": false,
             "required": [
                 "type",
-                "schema"
+                "json_schema"
             ]
         },
         {
@@ -3018,7 +3025,7 @@
                     "const": "json_schema",
                     "default": "json_schema"
                 },
-                "schema": {
+                "json_schema": {
                     "type": "object",
                     "additionalProperties": {
                         "oneOf": [
@@ -3047,7 +3054,7 @@
             "additionalProperties": false,
             "required": [
                 "type",
-                "schema"
+                "json_schema"
             ]
         },
         {
@@ -4002,7 +4009,8 @@
             "additionalProperties": false,
             "required": [
                 "event"
-            ]
+            ],
+            "title": "streamed agent turn completion response."
         },
         "AgentTurnResponseTurnCompletePayload": {
             "type": "object",
@@ -5004,24 +5012,6 @@
                         "type"
                     ]
                 },
-                {
-                    "type": "object",
-                    "properties": {
-                        "type": {
-                            "type": "string",
-                            "const": "custom",
-                            "default": "custom"
-                        },
-                        "validator_class": {
-                            "type": "string"
-                        }
-                    },
-                    "additionalProperties": false,
-                    "required": [
-                        "type",
-                        "validator_class"
-                    ]
-                },
                 {
                     "type": "object",
                     "properties": {
@@ -5304,24 +5294,6 @@
                         "type"
                     ]
                 },
-                {
-                    "type": "object",
-                    "properties": {
-                        "type": {
-                            "type": "string",
-                            "const": "custom",
-                            "default": "custom"
-                        },
-                        "validator_class": {
-                            "type": "string"
-                        }
-                    },
-                    "additionalProperties": false,
-                    "required": [
-                        "type",
-                        "validator_class"
-                    ]
-                },
                 {
                     "type": "object",
                     "properties": {
@@ -5376,7 +5348,7 @@
                 "type"
             ]
         },
-        "ScoringFunctionDefWithProvider": {
+        "ScoringFnDefWithProvider": {
             "type": "object",
             "properties": {
                 "identifier": {
@@ -5516,24 +5488,6 @@
                         "type"
                     ]
                 },
-                {
-                    "type": "object",
-                    "properties": {
-                        "type": {
-                            "type": "string",
-                            "const": "custom",
-                            "default": "custom"
-                        },
-                        "validator_class": {
-                            "type": "string"
-                        }
-                    },
-                    "additionalProperties": false,
-                    "required": [
-                        "type",
-                        "validator_class"
-                    ]
-                },
                 {
                     "type": "object",
                     "properties": {
@@ -5586,6 +5540,12 @@
                 },
                 "prompt_template": {
                     "type": "string"
+                },
+                "judge_score_regex": {
+                    "type": "array",
+                    "items": {
+                        "type": "string"
+                    }
                 }
             },
             "additionalProperties": false,
@@ -6339,10 +6299,10 @@
                 "finetuned_model": {
                     "$ref": "#/components/schemas/URL"
                 },
-                "dataset": {
+                "dataset_id": {
                     "type": "string"
                 },
-                "validation_dataset": {
+                "validation_dataset_id": {
                     "type": "string"
                 },
                 "algorithm": {
@@ -6412,8 +6372,8 @@
             "required": [
                 "job_uuid",
                 "finetuned_model",
-                "dataset",
-                "validation_dataset",
+                "dataset_id",
+                "validation_dataset_id",
                 "algorithm",
                 "algorithm_config",
                 "optimizer_config",
@@ -6595,7 +6555,7 @@
             "type": "object",
             "properties": {
                 "function_def": {
-                    "$ref": "#/components/schemas/ScoringFunctionDefWithProvider"
+                    "$ref": "#/components/schemas/ScoringFnDefWithProvider"
                 }
             },
             "additionalProperties": false,
@@ -6893,10 +6853,10 @@
                 "model": {
                     "type": "string"
                 },
-                "dataset": {
+                "dataset_id": {
                     "type": "string"
                 },
-                "validation_dataset": {
+                "validation_dataset_id": {
                     "type": "string"
                 },
                 "algorithm": {
@@ -6976,8 +6936,8 @@
             "required": [
                 "job_uuid",
                 "model",
-                "dataset",
-                "validation_dataset",
+                "dataset_id",
+                "validation_dataset_id",
                 "algorithm",
                 "algorithm_config",
                 "optimizer_config",
@@ -7102,57 +7062,57 @@
         }
     ],
     "tags": [
-        {
-            "name": "Eval"
-        },
-        {
-            "name": "ScoringFunctions"
-        },
-        {
-            "name": "SyntheticDataGeneration"
-        },
-        {
-            "name": "Inspect"
-        },
-        {
-            "name": "PostTraining"
-        },
-        {
-            "name": "Models"
-        },
-        {
-            "name": "Safety"
-        },
-        {
-            "name": "MemoryBanks"
-        },
-        {
-            "name": "DatasetIO"
-        },
         {
             "name": "Memory"
         },
-        {
-            "name": "Scoring"
-        },
-        {
-            "name": "Shields"
-        },
-        {
-            "name": "Datasets"
-        },
         {
             "name": "Inference"
         },
         {
-            "name": "Telemetry"
+            "name": "Eval"
         },
+        {
+            "name": "MemoryBanks"
+        },
+        {
+            "name": "Models"
+        },
         {
             "name": "BatchInference"
         },
+        {
+            "name": "PostTraining"
+        },
         {
             "name": "Agents"
         },
+        {
+            "name": "Shields"
+        },
+        {
+            "name": "Telemetry"
+        },
+        {
+            "name": "Inspect"
+        },
+        {
+            "name": "DatasetIO"
+        },
+        {
+            "name": "SyntheticDataGeneration"
+        },
+        {
+            "name": "Datasets"
+        },
+        {
+            "name": "Scoring"
+        },
+        {
+            "name": "ScoringFunctions"
+        },
+        {
+            "name": "Safety"
+        },
         {
             "name": "BuiltinTool",
             "description": "<SchemaDefinition schemaRef=\"#/components/schemas/BuiltinTool\" />"
@@ -7355,7 +7315,7 @@
         },
         {
             "name": "AgentTurnResponseStreamChunk",
-            "description": "<SchemaDefinition schemaRef=\"#/components/schemas/AgentTurnResponseStreamChunk\" />"
+            "description": "streamed agent turn completion response.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/AgentTurnResponseStreamChunk\" />"
         },
         {
             "name": "AgentTurnResponseTurnCompletePayload",
@@ -7486,8 +7446,8 @@
             "description": "<SchemaDefinition schemaRef=\"#/components/schemas/Parameter\" />"
         },
         {
-            "name": "ScoringFunctionDefWithProvider",
-            "description": "<SchemaDefinition schemaRef=\"#/components/schemas/ScoringFunctionDefWithProvider\" />"
+            "name": "ScoringFnDefWithProvider",
+            "description": "<SchemaDefinition schemaRef=\"#/components/schemas/ScoringFnDefWithProvider\" />"
         },
         {
             "name": "ShieldDefWithProvider",
@@ -7805,7 +7765,7 @@
                 "ScoreBatchResponse",
                 "ScoreRequest",
                 "ScoreResponse",
-                "ScoringFunctionDefWithProvider",
+                "ScoringFnDefWithProvider",
                 "ScoringResult",
                 "SearchToolDefinition",
                 "Session",

View file

@@ -190,6 +190,7 @@ components:
           $ref: '#/components/schemas/AgentTurnResponseEvent'
       required:
       - event
+      title: streamed agent turn completion response.
       type: object
     AgentTurnResponseTurnCompletePayload:
       additionalProperties: false
@@ -360,7 +361,7 @@
       oneOf:
       - additionalProperties: false
         properties:
-          schema:
+          json_schema:
             additionalProperties:
               oneOf:
               - type: 'null'
@@ -376,7 +377,7 @@
             type: string
         required:
         - type
-        - schema
+        - json_schema
         type: object
       - additionalProperties: false
         properties:
@@ -541,7 +542,7 @@
       oneOf:
       - additionalProperties: false
         properties:
-          schema:
+          json_schema:
             additionalProperties:
               oneOf:
               - type: 'null'
@@ -557,7 +558,7 @@
             type: string
         required:
        - type
-        - schema
+        - json_schema
         type: object
       - additionalProperties: false
         properties:
@@ -747,18 +748,6 @@ components:
         required:
         - type
         type: object
-      - additionalProperties: false
-        properties:
-          type:
-            const: custom
-            default: custom
-            type: string
-          validator_class:
-            type: string
-        required:
-        - type
-        - validator_class
-        type: object
       - additionalProperties: false
         properties:
           type:
@@ -1575,18 +1564,6 @@ components:
         required:
         - type
         type: object
-      - additionalProperties: false
-        properties:
-          type:
-            const: custom
-            default: custom
-            type: string
-          validator_class:
-            type: string
-        required:
-        - type
-        - validator_class
-        type: object
       - additionalProperties: false
         properties:
           type:
@@ -1724,7 +1701,7 @@ components:
         $ref: '#/components/schemas/RLHFAlgorithm'
       algorithm_config:
         $ref: '#/components/schemas/DPOAlignmentConfig'
-      dataset:
+      dataset_id:
         type: string
       finetuned_model:
         $ref: '#/components/schemas/URL'
@@ -1754,13 +1731,13 @@ components:
         $ref: '#/components/schemas/OptimizerConfig'
       training_config:
         $ref: '#/components/schemas/TrainingConfig'
-      validation_dataset:
+      validation_dataset_id:
         type: string
     required:
     - job_uuid
     - finetuned_model
-    - dataset
-    - validation_dataset
+    - dataset_id
+    - validation_dataset_id
     - algorithm
     - algorithm_config
     - optimizer_config
@@ -1899,7 +1876,7 @@ components:
     additionalProperties: false
     properties:
       function_def:
-        $ref: '#/components/schemas/ScoringFunctionDefWithProvider'
+        $ref: '#/components/schemas/ScoringFnDefWithProvider'
     required:
     - function_def
     type: object
@@ -2121,7 +2098,7 @@ components:
     required:
     - results
     type: object
-  ScoringFunctionDefWithProvider:
+  ScoringFnDefWithProvider:
     additionalProperties: false
     properties:
      context:
@@ -2129,6 +2106,10 @@ components:
         properties:
           judge_model:
             type: string
+          judge_score_regex:
+            items:
+              type: string
+            type: array
           prompt_template:
             type: string
         required:
@@ -2219,18 +2200,6 @@ components:
         required:
         - type
         type: object
-      - additionalProperties: false
-        properties:
-          type:
-            const: custom
-            default: custom
-            type: string
-          validator_class:
-            type: string
-        required:
-        - type
-        - validator_class
-        type: object
       - additionalProperties: false
         properties:
           type:
@@ -2484,7 +2453,7 @@ components:
       - $ref: '#/components/schemas/LoraFinetuningConfig'
       - $ref: '#/components/schemas/QLoraFinetuningConfig'
       - $ref: '#/components/schemas/DoraFinetuningConfig'
-      dataset:
+      dataset_id:
         type: string
       hyperparam_search_config:
         additionalProperties:
@@ -2514,13 +2483,13 @@ components:
         $ref: '#/components/schemas/OptimizerConfig'
       training_config:
         $ref: '#/components/schemas/TrainingConfig'
-      validation_dataset:
+      validation_dataset_id:
         type: string
     required:
     - job_uuid
     - model
-    - dataset
-    - validation_dataset
+    - dataset_id
+    - validation_dataset_id
     - algorithm
     - algorithm_config
     - optimizer_config
@@ -3029,7 +2998,7 @@ info:
   description: "This is the specification of the llama stack that provides\n \
     \ a set of endpoints and their corresponding interfaces that are tailored\
     \ to\n best leverage Llama Models. The specification is still in\
-    \ draft and subject to change.\n Generated at 2024-10-24 17:40:59.576117"
+    \ draft and subject to change.\n Generated at 2024-10-31 14:28:52.128905"
   title: '[DRAFT] Llama Stack Specification'
   version: 0.0.1
 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -3222,8 +3191,11 @@ paths:
           content:
             text/event-stream:
               schema:
-                $ref: '#/components/schemas/AgentTurnResponseStreamChunk'
-          description: OK
+                oneOf:
+                - $ref: '#/components/schemas/Turn'
+                - $ref: '#/components/schemas/AgentTurnResponseStreamChunk'
+          description: A single turn in an interaction with an Agentic System. **OR**
+            streamed agent turn completion response.
           tags:
           - Agents
   /agents/turn/get:
@@ -4122,7 +4094,7 @@
             application/json:
               schema:
                 oneOf:
-                - $ref: '#/components/schemas/ScoringFunctionDefWithProvider'
+                - $ref: '#/components/schemas/ScoringFnDefWithProvider'
                 - type: 'null'
           description: OK
       tags:
@@ -4142,7 +4114,7 @@
           content:
             application/jsonl:
               schema:
-                $ref: '#/components/schemas/ScoringFunctionDefWithProvider'
+                $ref: '#/components/schemas/ScoringFnDefWithProvider'
           description: OK
       tags:
       - ScoringFunctions
@@ -4308,23 +4280,23 @@ security:
 servers:
 - url: http://any-hosted-llama-stack.com
 tags:
-- name: Eval
-- name: ScoringFunctions
-- name: SyntheticDataGeneration
-- name: Inspect
-- name: PostTraining
-- name: Models
-- name: Safety
-- name: MemoryBanks
-- name: DatasetIO
 - name: Memory
-- name: Scoring
-- name: Shields
-- name: Datasets
 - name: Inference
-- name: Telemetry
+- name: Eval
+- name: MemoryBanks
+- name: Models
 - name: BatchInference
+- name: PostTraining
 - name: Agents
+- name: Shields
+- name: Telemetry
+- name: Inspect
+- name: DatasetIO
+- name: SyntheticDataGeneration
+- name: Datasets
+- name: Scoring
+- name: ScoringFunctions
+- name: Safety
 - description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
   name: BuiltinTool
 - description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
@@ -4483,8 +4455,11 @@ tags:
 - description: <SchemaDefinition schemaRef="#/components/schemas/AgentTurnResponseStepStartPayload"
     />
   name: AgentTurnResponseStepStartPayload
-- description: <SchemaDefinition schemaRef="#/components/schemas/AgentTurnResponseStreamChunk"
-    />
+- description: 'streamed agent turn completion response.
+
+    <SchemaDefinition schemaRef="#/components/schemas/AgentTurnResponseStreamChunk"
+    />'
   name: AgentTurnResponseStreamChunk
 - description: <SchemaDefinition schemaRef="#/components/schemas/AgentTurnResponseTurnCompletePayload"
     />
@@ -4577,9 +4552,9 @@ tags:
   name: PaginatedRowsResult
 - description: <SchemaDefinition schemaRef="#/components/schemas/Parameter" />
   name: Parameter
-- description: <SchemaDefinition schemaRef="#/components/schemas/ScoringFunctionDefWithProvider"
+- description: <SchemaDefinition schemaRef="#/components/schemas/ScoringFnDefWithProvider"
     />
-  name: ScoringFunctionDefWithProvider
+  name: ScoringFnDefWithProvider
 - description: <SchemaDefinition schemaRef="#/components/schemas/ShieldDefWithProvider"
     />
   name: ShieldDefWithProvider
@@ -4844,7 +4819,7 @@ x-tagGroups:
     - ScoreBatchResponse
     - ScoreRequest
     - ScoreResponse
-    - ScoringFunctionDefWithProvider
+    - ScoringFnDefWithProvider
     - ScoringResult
     - SearchToolDefinition
    - Session

View file

@@ -8,6 +8,7 @@ from datetime import datetime
 from enum import Enum
 from typing import (
     Any,
+    AsyncIterator,
     Dict,
     List,
     Literal,
@@ -405,6 +406,8 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
 
 @json_schema_type
 class AgentTurnResponseStreamChunk(BaseModel):
+    """streamed agent turn completion response."""
+
     event: AgentTurnResponseEvent
 
@@ -434,7 +437,7 @@ class Agents(Protocol):
         ],
         attachments: Optional[List[Attachment]] = None,
         stream: Optional[bool] = False,
-    ) -> AgentTurnResponseStreamChunk: ...
+    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
 
     @webmethod(route="/agents/turn/get")
     async def get_agents_turn(

View file

@@ -6,7 +6,15 @@
 
 from enum import Enum
-from typing import List, Literal, Optional, Protocol, runtime_checkable, Union
+from typing import (
+    AsyncIterator,
+    List,
+    Literal,
+    Optional,
+    Protocol,
+    runtime_checkable,
+    Union,
+)
 
 from llama_models.schema_utils import json_schema_type, webmethod
 
@@ -224,7 +232,7 @@ class Inference(Protocol):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
+    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...
 
     @webmethod(route="/inference/chat_completion")
     async def chat_completion(
@@ -239,7 +247,9 @@ class Inference(Protocol):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
+    ) -> Union[
+        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
+    ]: ...
 
     @webmethod(route="/inference/embeddings")
     async def embeddings(
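To make the new signatures concrete, here is a minimal runnable sketch (with stand-in types, not the real llama-stack models) of how a caller consumes a `Union[Response, AsyncIterator[Chunk]]` return value depending on `stream`:

import asyncio
from typing import AsyncIterator, Union

class Response:                      # stand-in for ChatCompletionResponse
    content = "full answer"

class StreamChunk:                   # stand-in for ChatCompletionResponseStreamChunk
    def __init__(self, delta: str):
        self.delta = delta

async def chat_completion(stream: bool = False) -> Union[Response, AsyncIterator[StreamChunk]]:
    if stream:
        async def gen() -> AsyncIterator[StreamChunk]:
            for piece in ("full ", "answer"):
                yield StreamChunk(piece)
        return gen()  # the iterator itself is returned; chunks arrive on iteration
    return Response()

async def main():
    resp = await chat_completion(stream=False)
    print(resp.content)
    chunks = await chat_completion(stream=True)
    async for chunk in chunks:
        print(chunk.delta, end="")
    print()

asyncio.run(main())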

View file

@@ -77,9 +77,9 @@ if [ -n "$LLAMA_STACK_DIR" ]; then
   # Install in editable format. We will mount the source code into the container
   # so that changes will be reflected in the container without having to do a
   # rebuild. This is just for development convenience.
-  add_to_docker "RUN pip install -e $stack_mount"
+  add_to_docker "RUN pip install --no-cache -e $stack_mount"
 else
-  add_to_docker "RUN pip install llama-stack"
+  add_to_docker "RUN pip install --no-cache llama-stack"
 fi
 
 if [ -n "$LLAMA_MODELS_DIR" ]; then
@@ -90,19 +90,19 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then
   add_to_docker <<EOF
 RUN pip uninstall -y llama-models
-RUN pip install $models_mount
+RUN pip install --no-cache $models_mount
 EOF
 fi
 
 if [ -n "$pip_dependencies" ]; then
-  add_to_docker "RUN pip install $pip_dependencies"
+  add_to_docker "RUN pip install --no-cache $pip_dependencies"
 fi
 
 if [ -n "$special_pip_deps" ]; then
   IFS='#' read -ra parts <<<"$special_pip_deps"
   for part in "${parts[@]}"; do
-    add_to_docker "RUN pip install $part"
+    add_to_docker "RUN pip install --no-cache $part"
   done
 fi

View file

@@ -0,0 +1,221 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import inspect
+import json
+from collections.abc import AsyncIterator
+from enum import Enum
+from typing import Any, get_args, get_origin, Type, Union
+
+import httpx
+from pydantic import BaseModel, parse_obj_as
+from termcolor import cprint
+
+from llama_stack.providers.datatypes import RemoteProviderConfig
+
+_CLIENT_CLASSES = {}
+
+
+async def get_client_impl(
+    protocol, additional_protocol, config: RemoteProviderConfig, _deps: Any
+):
+    client_class = create_api_client_class(protocol, additional_protocol)
+    impl = client_class(config.url)
+    await impl.initialize()
+    return impl
+
+
+def create_api_client_class(protocol, additional_protocol) -> Type:
+    if protocol in _CLIENT_CLASSES:
+        return _CLIENT_CLASSES[protocol]
+
+    protocols = [protocol, additional_protocol] if additional_protocol else [protocol]
+
+    class APIClient:
+        def __init__(self, base_url: str):
+            print(f"({protocol.__name__}) Connecting to {base_url}")
+            self.base_url = base_url.rstrip("/")
+            self.routes = {}
+
+            # Store routes for this protocol
+            for p in protocols:
+                for name, method in inspect.getmembers(p):
+                    if hasattr(method, "__webmethod__"):
+                        sig = inspect.signature(method)
+                        self.routes[name] = (method.__webmethod__, sig)
+
+        async def initialize(self):
+            pass
+
+        async def shutdown(self):
+            pass
+
+        async def __acall__(self, method_name: str, *args, **kwargs) -> Any:
+            assert method_name in self.routes, f"Unknown endpoint: {method_name}"
+
+            # TODO: make this more precise, same thing needs to happen in server.py
+            is_streaming = kwargs.get("stream", False)
+            if is_streaming:
+                return self._call_streaming(method_name, *args, **kwargs)
+            else:
+                return await self._call_non_streaming(method_name, *args, **kwargs)
+
+        async def _call_non_streaming(self, method_name: str, *args, **kwargs) -> Any:
+            _, sig = self.routes[method_name]
+
+            if sig.return_annotation is None:
+                return_type = None
+            else:
+                return_type = extract_non_async_iterator_type(sig.return_annotation)
+                assert (
+                    return_type
+                ), f"Could not extract return type for {sig.return_annotation}"
+
+            async with httpx.AsyncClient() as client:
+                params = self.httpx_request_params(method_name, *args, **kwargs)
+                response = await client.request(**params)
+                response.raise_for_status()
+
+                j = response.json()
+                if j is None:
+                    return None
+                return parse_obj_as(return_type, j)
+
+        async def _call_streaming(self, method_name: str, *args, **kwargs) -> Any:
+            webmethod, sig = self.routes[method_name]
+
+            return_type = extract_async_iterator_type(sig.return_annotation)
+            assert (
+                return_type
+            ), f"Could not extract return type for {sig.return_annotation}"
+
+            async with httpx.AsyncClient() as client:
+                params = self.httpx_request_params(method_name, *args, **kwargs)
+                async with client.stream(**params) as response:
+                    response.raise_for_status()
+
+                    async for line in response.aiter_lines():
+                        if line.startswith("data:"):
+                            data = line[len("data: ") :]
+                            try:
+                                if "error" in data:
+                                    cprint(data, "red")
+                                    continue
+
+                                yield parse_obj_as(return_type, json.loads(data))
+                            except Exception as e:
+                                print(data)
+                                print(f"Error with parsing or validation: {e}")
+
+        def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
+            webmethod, sig = self.routes[method_name]
+
+            parameters = list(sig.parameters.values())[1:]  # skip `self`
+            for i, param in enumerate(parameters):
+                if i >= len(args):
+                    break
+                kwargs[param.name] = args[i]
+
+            url = f"{self.base_url}{webmethod.route}"
+
+            def convert(value):
+                if isinstance(value, list):
+                    return [convert(v) for v in value]
+                elif isinstance(value, dict):
+                    return {k: convert(v) for k, v in value.items()}
+                elif isinstance(value, BaseModel):
+                    return json.loads(value.model_dump_json())
+                elif isinstance(value, Enum):
+                    return value.value
+                else:
+                    return value
+
+            params = {}
+            data = {}
+            if webmethod.method == "GET":
+                params.update(kwargs)
+            else:
+                data.update(convert(kwargs))
+
+            return dict(
+                method=webmethod.method or "POST",
+                url=url,
+                headers={"Content-Type": "application/json"},
+                params=params,
+                json=data,
+                timeout=30,
+            )
+
+    # Add protocol methods to the wrapper
+    for p in protocols:
+        for name, method in inspect.getmembers(p):
+            if hasattr(method, "__webmethod__"):
+
+                async def method_impl(self, *args, method_name=name, **kwargs):
+                    return await self.__acall__(method_name, *args, **kwargs)
+
+                method_impl.__name__ = name
+                method_impl.__qualname__ = f"APIClient.{name}"
+                method_impl.__signature__ = inspect.signature(method)
+                setattr(APIClient, name, method_impl)
+
+    # Name the class after the protocol
+    APIClient.__name__ = f"{protocol.__name__}Client"
+    _CLIENT_CLASSES[protocol] = APIClient
+    return APIClient
+
+
+# not quite general these methods are
+def extract_non_async_iterator_type(type_hint):
+    if get_origin(type_hint) is Union:
+        args = get_args(type_hint)
+        for arg in args:
+            if not issubclass(get_origin(arg) or arg, AsyncIterator):
+                return arg
+    return type_hint
+
+
+def extract_async_iterator_type(type_hint):
+    if get_origin(type_hint) is Union:
+        args = get_args(type_hint)
+        for arg in args:
+            if issubclass(get_origin(arg) or arg, AsyncIterator):
+                inner_args = get_args(arg)
+                return inner_args[0]
+    return None
+
+
+async def example(model: str = None):
+    from llama_stack.apis.inference import Inference, UserMessage  # noqa: F403
+    from llama_stack.apis.inference.event_logger import EventLogger
+
+    client_class = create_api_client_class(Inference, None)
+    client = client_class("http://localhost:5003")
+
+    if not model:
+        model = "Llama3.2-3B-Instruct"
+
+    message = UserMessage(
+        content="hello world, write me a 2 sentence poem about the moon"
+    )
+    cprint(f"User>{message.content}", "green")
+
+    stream = True
+    iterator = await client.chat_completion(
+        model=model,
+        messages=[message],
+        stream=stream,
+    )
+
+    async for log in EventLogger().log(iterator):
+        log.print()
+
+
+if __name__ == "__main__":
+    import asyncio
+
+    asyncio.run(example())
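As a sanity check on the two extractors above, a small stand-alone snippet (stand-in classes; the filtering logic mirrors theirs) showing how the client derives the payload type for each call style:

from collections.abc import AsyncIterator as ABCAsyncIterator
from typing import AsyncIterator, get_args, get_origin, Union

class Full: ...   # stand-in, e.g. ChatCompletionResponse
class Chunk: ...  # stand-in, e.g. ChatCompletionResponseStreamChunk

hint = Union[Full, AsyncIterator[Chunk]]

# Non-streaming call: pick the union member that is not an AsyncIterator.
non_stream = [
    a for a in get_args(hint)
    if not issubclass(get_origin(a) or a, ABCAsyncIterator)
]
assert non_stream == [Full]

# Streaming call: unwrap the AsyncIterator member down to its item type.
stream_item = [
    get_args(a)[0] for a in get_args(hint)
    if issubclass(get_origin(a) or a, ABCAsyncIterator)
]
assert stream_item == [Chunk]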

View file

@@ -40,19 +40,21 @@ def api_protocol_map() -> Dict[Api, Any]:
         Api.safety: Safety,
         Api.shields: Shields,
         Api.telemetry: Telemetry,
-        Api.datasets: Datasets,
         Api.datasetio: DatasetIO,
-        Api.scoring_functions: ScoringFunctions,
+        Api.datasets: Datasets,
         Api.scoring: Scoring,
+        Api.scoring_functions: ScoringFunctions,
         Api.eval: Eval,
     }
 
 
 def additional_protocols_map() -> Dict[Api, Any]:
     return {
-        Api.inference: ModelsProtocolPrivate,
-        Api.memory: MemoryBanksProtocolPrivate,
-        Api.safety: ShieldsProtocolPrivate,
+        Api.inference: (ModelsProtocolPrivate, Models),
+        Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks),
+        Api.safety: (ShieldsProtocolPrivate, Shields),
+        Api.datasetio: (DatasetsProtocolPrivate, Datasets),
+        Api.scoring: (ScoringFunctionsProtocolPrivate, ScoringFunctions),
     }
@@ -112,8 +114,6 @@ async def resolve_impls(
         if info.router_api.value not in apis_to_serve:
             continue
 
-        available_providers = providers_with_specs[f"inner-{info.router_api.value}"]
-
         providers_with_specs[info.routing_table_api.value] = {
             "__builtin__": ProviderWithSpec(
                 provider_id="__routing_table__",
@@ -246,14 +246,21 @@ async def instantiate_provider(
     args = []
     if isinstance(provider_spec, RemoteProviderSpec):
-        if provider_spec.adapter:
-            method = "get_adapter_impl"
-        else:
-            method = "get_client_impl"
-
         config_type = instantiate_class_type(provider_spec.config_class)
         config = config_type(**provider.config)
-        args = [config, deps]
+
+        if provider_spec.adapter:
+            method = "get_adapter_impl"
+            args = [config, deps]
+        else:
+            method = "get_client_impl"
+            protocol = protocols[provider_spec.api]
+            if provider_spec.api in additional_protocols:
+                _, additional_protocol = additional_protocols[provider_spec.api]
+            else:
+                additional_protocol = None
+            args = [protocol, additional_protocol, config, deps]
 
     elif isinstance(provider_spec, AutoRoutedProviderSpec):
         method = "get_auto_router_impl"
@@ -282,7 +289,7 @@
         not isinstance(provider_spec, AutoRoutedProviderSpec)
         and provider_spec.api in additional_protocols
     ):
-        additional_api = additional_protocols[provider_spec.api]
+        additional_api, _ = additional_protocols[provider_spec.api]
         check_protocol_compliance(impl, additional_api)
 
     return impl

View file

@@ -22,6 +22,13 @@ def get_impl_api(p: Any) -> Api:
 
 async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
     api = get_impl_api(p)
+
+    if obj.provider_id == "remote":
+        # if this is just a passthrough, we want to let the remote
+        # end actually do the registration with the correct provider
+        obj = obj.model_copy(deep=True)
+        obj.provider_id = ""
+
     if api == Api.inference:
         await p.register_model(obj)
     elif api == Api.safety:
@@ -51,11 +58,22 @@ class CommonRoutingTableImpl(RoutingTable):
     async def initialize(self) -> None:
         self.registry: Registry = {}
 
-        def add_objects(objs: List[RoutableObjectWithProvider]) -> None:
+        def add_objects(
+            objs: List[RoutableObjectWithProvider], provider_id: str, cls
+        ) -> None:
             for obj in objs:
                 if obj.identifier not in self.registry:
                     self.registry[obj.identifier] = []
 
+                if cls is None:
+                    obj.provider_id = provider_id
+                else:
+                    if provider_id == "remote":
+                        # if this is just a passthrough, we got the *WithProvider object
+                        # so we should just override the provider in-place
+                        obj.provider_id = provider_id
+                    else:
+                        obj = cls(**obj.model_dump(), provider_id=provider_id)
+
                 self.registry[obj.identifier].append(obj)
 
         for pid, p in self.impls_by_provider_id.items():
@@ -63,47 +81,27 @@ class CommonRoutingTableImpl(RoutingTable):
             if api == Api.inference:
                 p.model_store = self
                 models = await p.list_models()
-                add_objects(
-                    [ModelDefWithProvider(**m.dict(), provider_id=pid) for m in models]
-                )
+                add_objects(models, pid, ModelDefWithProvider)
 
             elif api == Api.safety:
                 p.shield_store = self
                 shields = await p.list_shields()
-                add_objects(
-                    [
-                        ShieldDefWithProvider(**s.dict(), provider_id=pid)
-                        for s in shields
-                    ]
-                )
+                add_objects(shields, pid, ShieldDefWithProvider)
 
             elif api == Api.memory:
                 p.memory_bank_store = self
                 memory_banks = await p.list_memory_banks()
-
-                # do in-memory updates due to pesky Annotated unions
-                for m in memory_banks:
-                    m.provider_id = pid
-
-                add_objects(memory_banks)
+                add_objects(memory_banks, pid, None)
 
             elif api == Api.datasetio:
                 p.dataset_store = self
                 datasets = await p.list_datasets()
-
-                # do in-memory updates due to pesky Annotated unions
-                for d in datasets:
-                    d.provider_id = pid
+                add_objects(datasets, pid, DatasetDefWithProvider)
 
             elif api == Api.scoring:
                 p.scoring_function_store = self
                 scoring_functions = await p.list_scoring_functions()
-                add_objects(
-                    [
-                        ScoringFnDefWithProvider(**s.dict(), provider_id=pid)
-                        for s in scoring_functions
-                    ]
-                )
+                add_objects(scoring_functions, pid, ScoringFnDefWithProvider)
 
     async def shutdown(self) -> None:
         for p in self.impls_by_provider_id.values():

View file

@@ -55,7 +55,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+    ) -> AsyncGenerator:
         raise NotImplementedError()
 
     @staticmethod
@@ -290,23 +290,130 @@
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
-        # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> (
-        AsyncGenerator
-    ):  # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
-        bedrock_model = self.map_to_provider_model(model)
-        inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
-            sampling_params
+    ) -> Union[
+        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
+    ]:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
+            stream=stream,
+            logprobs=logprobs,
         )
-        tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
+        if stream:
+            return self._stream_chat_completion(request)
+        else:
+            return await self._nonstream_chat_completion(request)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
+        params = self._get_params_for_chat_completion(request)
+        converse_api_res = self.client.converse(**params)
+
+        output_message = BedrockInferenceAdapter._bedrock_message_to_message(
+            converse_api_res
+        )
+
+        return ChatCompletionResponse(
+            completion_message=output_message,
+            logprobs=None,
+        )
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        params = self._get_params_for_chat_completion(request)
+        converse_stream_api_res = self.client.converse_stream(**params)
+        event_stream = converse_stream_api_res["stream"]
+
+        for chunk in event_stream:
+            if "messageStart" in chunk:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.start,
+                        delta="",
+                    )
+                )
+            elif "contentBlockStart" in chunk:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=ToolCallDelta(
+                            content=ToolCall(
+                                tool_name=chunk["contentBlockStart"]["toolUse"]["name"],
+                                call_id=chunk["contentBlockStart"]["toolUse"][
+                                    "toolUseId"
+                                ],
+                            ),
+                            parse_status=ToolCallParseStatus.started,
+                        ),
+                    )
+                )
+            elif "contentBlockDelta" in chunk:
+                if "text" in chunk["contentBlockDelta"]["delta"]:
+                    delta = chunk["contentBlockDelta"]["delta"]["text"]
+                else:
+                    delta = ToolCallDelta(
+                        content=ToolCall(
+                            arguments=chunk["contentBlockDelta"]["delta"]["toolUse"][
+                                "input"
+                            ]
+                        ),
+                        parse_status=ToolCallParseStatus.success,
+                    )
+
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=delta,
+                    )
+                )
+            elif "contentBlockStop" in chunk:
+                # Ignored
+                pass
+            elif "messageStop" in chunk:
+                stop_reason = (
+                    BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
+                        chunk["messageStop"]["stopReason"]
+                    )
+                )
+
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.complete,
+                        delta="",
+                        stop_reason=stop_reason,
+                    )
+                )
+            elif "metadata" in chunk:
+                # Ignored
+                pass
+            else:
+                # Ignored
+                pass
+
+    def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
+        bedrock_model = self.map_to_provider_model(request.model)
+        inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
+            request.sampling_params
+        )
+
+        tool_config = BedrockInferenceAdapter._tools_to_tool_config(
+            request.tools, request.tool_choice
+        )
         bedrock_messages, system_bedrock_messages = (
-            BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
+            BedrockInferenceAdapter._messages_to_bedrock_messages(request.messages)
         )
 
         converse_api_params = {
@@ -317,93 +424,12 @@
             converse_api_params["inferenceConfig"] = inference_config
 
         # Tool use is not supported in streaming mode
-        if tool_config and not stream:
+        if tool_config and not request.stream:
             converse_api_params["toolConfig"] = tool_config
         if system_bedrock_messages:
             converse_api_params["system"] = system_bedrock_messages
 
-        if not stream:
-            converse_api_res = self.client.converse(**converse_api_params)
-
-            output_message = BedrockInferenceAdapter._bedrock_message_to_message(
-                converse_api_res
-            )
-
-            yield ChatCompletionResponse(
-                completion_message=output_message,
-                logprobs=None,
-            )
-        else:
-            converse_stream_api_res = self.client.converse_stream(**converse_api_params)
-            event_stream = converse_stream_api_res["stream"]
-
-            for chunk in event_stream:
-                if "messageStart" in chunk:
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.start,
-                            delta="",
-                        )
-                    )
-                elif "contentBlockStart" in chunk:
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content=ToolCall(
-                                    tool_name=chunk["contentBlockStart"]["toolUse"][
-                                        "name"
-                                    ],
-                                    call_id=chunk["contentBlockStart"]["toolUse"][
-                                        "toolUseId"
-                                    ],
-                                ),
-                                parse_status=ToolCallParseStatus.started,
-                            ),
-                        )
-                    )
-                elif "contentBlockDelta" in chunk:
-                    if "text" in chunk["contentBlockDelta"]["delta"]:
-                        delta = chunk["contentBlockDelta"]["delta"]["text"]
-                    else:
-                        delta = ToolCallDelta(
-                            content=ToolCall(
-                                arguments=chunk["contentBlockDelta"]["delta"][
-                                    "toolUse"
-                                ]["input"]
-                            ),
-                            parse_status=ToolCallParseStatus.success,
-                        )
-
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=delta,
-                        )
-                    )
-                elif "contentBlockStop" in chunk:
-                    # Ignored
-                    pass
-                elif "messageStop" in chunk:
-                    stop_reason = (
-                        BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
-                            chunk["messageStop"]["stopReason"]
-                        )
-                    )
-
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.complete,
-                            delta="",
-                            stop_reason=stop_reason,
-                        )
-                    )
-                elif "metadata" in chunk:
-                    # Ignored
-                    pass
-                else:
-                    # Ignored
-                    pass
+        return converse_api_params
 
     async def embeddings(
         self,

View file

@@ -75,7 +75,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             for model in self.client.models.list()
         ]
 
-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -86,7 +86,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
         raise NotImplementedError()
 
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -111,7 +111,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         if stream:
             return self._stream_chat_completion(request, self.client)
         else:
-            return self._nonstream_chat_completion(request, self.client)
+            return await self._nonstream_chat_completion(request, self.client)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: OpenAI

View file

@@ -60,7 +60,7 @@ class MemoryBanksProtocolPrivate(Protocol):
 class DatasetsProtocolPrivate(Protocol):
     async def list_datasets(self) -> List[DatasetDef]: ...
 
-    async def register_datasets(self, dataset_def: DatasetDef) -> None: ...
+    async def register_dataset(self, dataset_def: DatasetDef) -> None: ...
 
 
 class ScoringFunctionsProtocolPrivate(Protocol):
@@ -171,7 +171,7 @@ as being "Llama Stack compatible"
     def module(self) -> str:
         if self.adapter:
             return self.adapter.module
-        return f"llama_stack.apis.{self.api.value}.client"
+        return "llama_stack.distribution.client"
 
     @property
     def pip_packages(self) -> List[str]:

View file

@@ -81,7 +81,9 @@ func encodeMessage(message: Components.Schemas.ChatCompletionRequest.messagesPayload
             switch (m.content) {
             case .case1(let c):
                 prompt += _processContent(c)
-            case .case2(let c):
+            case .ImageMedia(let c):
+                prompt += _processContent(c)
+            case .case3(let c):
                 prompt += _processContent(c)
             }
         case .CompletionMessage(let m):

View file

@@ -26,6 +26,7 @@ from dotenv import load_dotenv
 #
 # ```bash
 # PROVIDER_ID=<your_provider> \
+#   MODEL_ID=<your_model> \
 #   PROVIDER_CONFIG=provider_config.yaml \
 #   pytest -s llama_stack/providers/tests/agents/test_agents.py \
 #   --tb=short --disable-warnings
@@ -44,7 +45,7 @@ async def agents_settings():
         "impl": impls[Api.agents],
         "memory_impl": impls[Api.memory],
         "common_params": {
-            "model": "Llama3.1-8B-Instruct",
+            "model": os.environ["MODEL_ID"] or "Llama3.1-8B-Instruct",
             "instructions": "You are a helpful assistant.",
         },
     }

View file

@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import os
 
 import pytest
 import pytest_asyncio
@@ -73,7 +72,6 @@ async def register_memory_bank(banks_impl: MemoryBanks):
         embedding_model="all-MiniLM-L6-v2",
         chunk_size_in_tokens=512,
         overlap_size_in_tokens=64,
-        provider_id=os.environ["PROVIDER_ID"],
     )
     await banks_impl.register_memory_bank(bank)