forked from phoenix-oss/llama-stack-mirror
Adding safety adapter for Together
This commit is contained in:
parent
0d2eb3bd25
commit
b85d675c6f
8 changed files with 188 additions and 23 deletions
|
@ -6,7 +6,7 @@
|
|||
|
||||
import json
|
||||
import threading
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .utils.dynamic import instantiate_class_type
|
||||
|
||||
|
@ -17,8 +17,8 @@ def get_request_provider_data() -> Any:
|
|||
return getattr(_THREAD_LOCAL, "provider_data", None)
|
||||
|
||||
|
||||
def set_request_provider_data(headers: Dict[str, str], validator_class: Optional[str]):
|
||||
if not validator_class:
|
||||
def set_request_provider_data(headers: Dict[str, str], validator_classes: List[str]):
|
||||
if not validator_classes:
|
||||
return
|
||||
|
||||
keys = [
|
||||
|
@ -39,11 +39,12 @@ def set_request_provider_data(headers: Dict[str, str], validator_class: Optional
|
|||
print("Provider data not encoded as a JSON object!", val)
|
||||
return
|
||||
|
||||
validator = instantiate_class_type(validator_class)
|
||||
try:
|
||||
provider_data = validator(**val)
|
||||
except Exception as e:
|
||||
print("Error parsing provider data", e)
|
||||
return
|
||||
|
||||
_THREAD_LOCAL.provider_data = provider_data
|
||||
for validator_class in validator_classes:
|
||||
validator = instantiate_class_type(validator_class)
|
||||
try:
|
||||
provider_data = validator(**val)
|
||||
if provider_data:
|
||||
_THREAD_LOCAL.provider_data = provider_data
|
||||
return
|
||||
except Exception as e:
|
||||
print("Error parsing provider data", e)
|
||||
|
|
|
@ -15,6 +15,7 @@ from collections.abc import (
|
|||
AsyncIterator as AsyncIteratorABC,
|
||||
)
|
||||
from contextlib import asynccontextmanager
|
||||
from http import HTTPStatus
|
||||
from ssl import SSLError
|
||||
from typing import (
|
||||
Any,
|
||||
|
@ -88,7 +89,7 @@ async def global_exception_handler(request: Request, exc: Exception):
|
|||
)
|
||||
|
||||
|
||||
def translate_exception(exc: Exception) -> HTTPException:
|
||||
def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]:
|
||||
if isinstance(exc, ValidationError):
|
||||
exc = RequestValidationError(exc.raw_errors)
|
||||
|
||||
|
@ -207,7 +208,7 @@ def create_dynamic_passthrough(
|
|||
|
||||
|
||||
def create_dynamic_typed_route(
|
||||
func: Any, method: str, provider_data_validator: Optional[str]
|
||||
func: Any, method: str, provider_data_validators: List[str]
|
||||
):
|
||||
hints = get_type_hints(func)
|
||||
response_model = hints.get("return")
|
||||
|
@ -223,7 +224,7 @@ def create_dynamic_typed_route(
|
|||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
set_request_provider_data(request.headers, provider_data_validator)
|
||||
set_request_provider_data(request.headers, provider_data_validators)
|
||||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
|
@ -254,7 +255,7 @@ def create_dynamic_typed_route(
|
|||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
set_request_provider_data(request.headers, provider_data_validator)
|
||||
set_request_provider_data(request.headers, provider_data_validators)
|
||||
|
||||
try:
|
||||
return (
|
||||
|
@ -415,6 +416,15 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
|
||||
app = FastAPI()
|
||||
|
||||
# Health check is added to enable deploying the docker container image on Kubernetes which require
|
||||
# a health check that can return 200 for readiness and liveness check
|
||||
class HealthCheck(BaseModel):
|
||||
status: str = "OK"
|
||||
|
||||
@app.get("/healthcheck", status_code=HTTPStatus.OK, response_model=HealthCheck)
|
||||
async def healthcheck():
|
||||
return HealthCheck(status="OK")
|
||||
|
||||
impls, specs = asyncio.run(resolve_impls_with_routing(config))
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
|
@ -454,15 +464,22 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
)
|
||||
|
||||
impl_method = getattr(impl, endpoint.name)
|
||||
|
||||
validators = []
|
||||
if isinstance(provider_spec, AutoRoutedProviderSpec):
|
||||
inner_specs = specs[provider_spec.routing_table_api].inner_specs
|
||||
for spec in inner_specs:
|
||||
if spec.provider_data_validator:
|
||||
validators.append(spec.provider_data_validator)
|
||||
elif not isinstance(provider_spec, RoutingTableProviderSpec):
|
||||
if provider_spec.provider_data_validator:
|
||||
validators.append(provider_spec.provider_data_validator)
|
||||
|
||||
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
||||
create_dynamic_typed_route(
|
||||
impl_method,
|
||||
endpoint.method,
|
||||
(
|
||||
provider_spec.provider_data_validator
|
||||
if not isinstance(provider_spec, RoutingTableProviderSpec)
|
||||
else None
|
||||
),
|
||||
validators,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ distribution_spec:
|
|||
providers:
|
||||
inference: remote::together
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
safety: remote::together
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue