forked from phoenix-oss/llama-stack-mirror
This is yet another of those large PRs (hopefully we will have less and less of them as things mature fast). This one introduces substantial improvements and some simplifications to the stack. Most important bits: * Agents reference implementation now has support for session / turn persistence. The default implementation uses sqlite but there's also support for using Redis. * We have re-architected the structure of the Stack APIs to allow for more flexible routing. The motivating use cases are: - routing model A to ollama and model B to a remote provider like Together - routing shield A to local impl while shield B to a remote provider like Bedrock - routing a vector memory bank to Weaviate while routing a keyvalue memory bank to Redis * Support for provider specific parameters to be passed from the clients. A client can pass data using `x_llamastack_provider_data` parameter which can be type-checked and provided to the Adapter implementations.
316 lines
12 KiB
Python
316 lines
12 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import inspect
|
|
import json
|
|
from enum import Enum
|
|
|
|
from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
|
|
|
|
from pydantic import BaseModel
|
|
from pydantic.fields import FieldInfo
|
|
from pydantic_core import PydanticUndefinedType
|
|
|
|
from typing_extensions import Annotated
|
|
|
|
|
|
def is_list_of_primitives(field_type):
|
|
"""Check if a field type is a List of primitive types."""
|
|
origin = get_origin(field_type)
|
|
if origin is List or origin is list:
|
|
args = get_args(field_type)
|
|
if len(args) == 1 and args[0] in (int, float, str, bool):
|
|
return True
|
|
return False
|
|
|
|
|
|
def is_basemodel_without_fields(typ):
|
|
return (
|
|
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0
|
|
)
|
|
|
|
|
|
def can_recurse(typ):
|
|
return (
|
|
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
|
|
)
|
|
|
|
|
|
def get_literal_values(field):
|
|
"""Extract literal values from a field if it's a Literal type."""
|
|
if get_origin(field.annotation) is Literal:
|
|
return get_args(field.annotation)
|
|
return None
|
|
|
|
|
|
def is_optional(field_type):
|
|
"""Check if a field type is Optional."""
|
|
return get_origin(field_type) is Union and type(None) in get_args(field_type)
|
|
|
|
|
|
def get_non_none_type(field_type):
|
|
"""Get the non-None type from an Optional type."""
|
|
return next(arg for arg in get_args(field_type) if arg is not type(None))
|
|
|
|
|
|
def manually_validate_field(model: Type[BaseModel], field_name: str, value: Any):
|
|
validators = model.__pydantic_decorators__.field_validators
|
|
for _name, validator in validators.items():
|
|
if field_name in validator.info.fields:
|
|
validator.func(value)
|
|
|
|
return value
|
|
|
|
|
|
def is_discriminated_union(typ) -> bool:
|
|
if isinstance(typ, FieldInfo):
|
|
return typ.discriminator
|
|
else:
|
|
if not (get_origin(typ) is Annotated):
|
|
return False
|
|
args = get_args(typ)
|
|
return len(args) >= 2 and args[1].discriminator
|
|
|
|
|
|
def prompt_for_discriminated_union(
|
|
field_name,
|
|
typ,
|
|
existing_value,
|
|
):
|
|
if isinstance(typ, FieldInfo):
|
|
inner_type = typ.annotation
|
|
discriminator = typ.discriminator
|
|
default_value = typ.default
|
|
else:
|
|
args = get_args(typ)
|
|
inner_type = args[0]
|
|
discriminator = args[1].discriminator
|
|
default_value = args[1].default
|
|
|
|
union_types = get_args(inner_type)
|
|
# Find the discriminator field in each union type
|
|
type_map = {}
|
|
for t in union_types:
|
|
disc_field = t.__fields__[discriminator]
|
|
literal_values = get_literal_values(disc_field)
|
|
if literal_values:
|
|
for value in literal_values:
|
|
type_map[value] = t
|
|
|
|
while True:
|
|
prompt = f"Enter `{discriminator}` for {field_name} (options: {', '.join(type_map.keys())})"
|
|
if default_value is not None:
|
|
prompt += f" (default: {default_value})"
|
|
|
|
discriminator_value = input(f"{prompt}: ")
|
|
if discriminator_value == "" and default_value is not None:
|
|
discriminator_value = default_value
|
|
|
|
if discriminator_value in type_map:
|
|
chosen_type = type_map[discriminator_value]
|
|
print(f"\nConfiguring {chosen_type.__name__}:")
|
|
|
|
if existing_value and (
|
|
getattr(existing_value, discriminator) != discriminator_value
|
|
):
|
|
existing_value = None
|
|
|
|
sub_config = prompt_for_config(chosen_type, existing_value)
|
|
# Set the discriminator field in the sub-config
|
|
setattr(sub_config, discriminator, discriminator_value)
|
|
return sub_config
|
|
else:
|
|
print(f"Invalid {discriminator}. Please try again.")
|
|
|
|
|
|
# This is somewhat elaborate, but does not purport to be comprehensive in any way.
|
|
# We should add handling for the most common cases to tide us over.
|
|
#
|
|
# doesn't support List[nested_class] yet or Dicts of any kind. needs a bunch of
|
|
# unit tests for coverage.
|
|
def prompt_for_config(
|
|
config_type: type[BaseModel], existing_config: Optional[BaseModel] = None
|
|
) -> BaseModel:
|
|
"""
|
|
Recursively prompt the user for configuration values based on a Pydantic BaseModel.
|
|
|
|
Args:
|
|
config_type: A Pydantic BaseModel class representing the configuration structure.
|
|
|
|
Returns:
|
|
An instance of the config_type with user-provided values.
|
|
"""
|
|
config_data = {}
|
|
|
|
for field_name, field in config_type.__fields__.items():
|
|
field_type = field.annotation
|
|
existing_value = (
|
|
getattr(existing_config, field_name) if existing_config else None
|
|
)
|
|
if existing_value:
|
|
default_value = existing_value
|
|
else:
|
|
default_value = (
|
|
field.default
|
|
if not isinstance(field.default, PydanticUndefinedType)
|
|
else None
|
|
)
|
|
is_required = field.is_required
|
|
|
|
# Skip fields with Literal type
|
|
if get_origin(field_type) is Literal:
|
|
continue
|
|
|
|
# Skip fields with no type annotations
|
|
if is_basemodel_without_fields(field_type):
|
|
config_data[field_name] = field_type()
|
|
continue
|
|
|
|
if inspect.isclass(field_type) and issubclass(field_type, Enum):
|
|
prompt = f"Choose {field_name} (options: {', '.join(e.name for e in field_type)}):"
|
|
while True:
|
|
# this branch does not handle existing and default values yet
|
|
user_input = input(prompt + " ")
|
|
try:
|
|
value = field_type[user_input]
|
|
validated_value = manually_validate_field(config_type, field, value)
|
|
config_data[field_name] = validated_value
|
|
break
|
|
except KeyError:
|
|
print(
|
|
f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}"
|
|
)
|
|
continue
|
|
|
|
if is_discriminated_union(field):
|
|
config_data[field_name] = prompt_for_discriminated_union(
|
|
field_name, field, existing_value
|
|
)
|
|
continue
|
|
|
|
if is_optional(field_type) and can_recurse(get_non_none_type(field_type)):
|
|
prompt = f"Do you want to configure {field_name}? (y/n): "
|
|
if input(prompt).lower() == "n":
|
|
config_data[field_name] = None
|
|
continue
|
|
nested_type = get_non_none_type(field_type)
|
|
print(f"Entering sub-configuration for {field_name}:")
|
|
config_data[field_name] = prompt_for_config(nested_type, existing_value)
|
|
elif is_optional(field_type) and is_discriminated_union(
|
|
get_non_none_type(field_type)
|
|
):
|
|
prompt = f"Do you want to configure {field_name}? (y/n): "
|
|
if input(prompt).lower() == "n":
|
|
config_data[field_name] = None
|
|
continue
|
|
nested_type = get_non_none_type(field_type)
|
|
config_data[field_name] = prompt_for_discriminated_union(
|
|
field_name,
|
|
nested_type,
|
|
existing_value,
|
|
)
|
|
elif can_recurse(field_type):
|
|
print(f"\nEntering sub-configuration for {field_name}:")
|
|
config_data[field_name] = prompt_for_config(
|
|
field_type,
|
|
existing_value,
|
|
)
|
|
else:
|
|
prompt = f"Enter value for {field_name}"
|
|
if existing_value is not None:
|
|
prompt += f" (existing: {existing_value})"
|
|
elif default_value is not None:
|
|
prompt += f" (default: {default_value})"
|
|
if is_optional(field_type):
|
|
prompt += " (optional)"
|
|
elif is_required:
|
|
prompt += " (required)"
|
|
prompt += ": "
|
|
|
|
while True:
|
|
user_input = input(prompt)
|
|
if user_input == "":
|
|
if default_value is not None:
|
|
config_data[field_name] = default_value
|
|
break
|
|
elif is_optional(field_type) or not is_required:
|
|
config_data[field_name] = None
|
|
break
|
|
else:
|
|
print("This field is required. Please provide a value.")
|
|
continue
|
|
else:
|
|
try:
|
|
# Handle Optional types
|
|
if is_optional(field_type):
|
|
if user_input.lower() == "none":
|
|
value = None
|
|
else:
|
|
field_type = get_non_none_type(field_type)
|
|
value = user_input
|
|
|
|
# Handle List of primitives
|
|
elif is_list_of_primitives(field_type):
|
|
try:
|
|
value = json.loads(user_input)
|
|
if not isinstance(value, list):
|
|
raise ValueError(
|
|
"Input must be a JSON-encoded list"
|
|
)
|
|
element_type = get_args(field_type)[0]
|
|
value = [element_type(item) for item in value]
|
|
|
|
except json.JSONDecodeError:
|
|
print(
|
|
'Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]'
|
|
)
|
|
continue
|
|
except ValueError as e:
|
|
print(f"{str(e)}")
|
|
continue
|
|
|
|
elif get_origin(field_type) is dict:
|
|
try:
|
|
value = json.loads(user_input)
|
|
if not isinstance(value, dict):
|
|
raise ValueError(
|
|
"Input must be a JSON-encoded dictionary"
|
|
)
|
|
|
|
except json.JSONDecodeError:
|
|
print(
|
|
"Invalid JSON. Please enter a valid JSON-encoded dict."
|
|
)
|
|
continue
|
|
|
|
# Convert the input to the correct type
|
|
elif inspect.isclass(field_type) and issubclass(
|
|
field_type, BaseModel
|
|
):
|
|
# For nested BaseModels, we assume a dictionary-like string input
|
|
import ast
|
|
|
|
value = field_type(**ast.literal_eval(user_input))
|
|
else:
|
|
value = field_type(user_input)
|
|
|
|
except ValueError:
|
|
print(
|
|
f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}"
|
|
)
|
|
continue
|
|
|
|
try:
|
|
# Validate the field using our manual validation function
|
|
validated_value = manually_validate_field(
|
|
config_type, field_name, value
|
|
)
|
|
config_data[field_name] = validated_value
|
|
break
|
|
except ValueError as e:
|
|
print(f"Validation error: {str(e)}")
|
|
|
|
return config_type(**config_data)
|