Adding safety adapter for Together

This commit is contained in:
Yogish Baliga 2024-09-20 09:35:01 -07:00 committed by Ashwin Bharambe
parent 0d2eb3bd25
commit b85d675c6f
8 changed files with 188 additions and 23 deletions

View file

@ -15,6 +15,7 @@ from collections.abc import (
AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager
from http import HTTPStatus
from ssl import SSLError
from typing import (
Any,
@ -88,7 +89,7 @@ async def global_exception_handler(request: Request, exc: Exception):
)
def translate_exception(exc: Exception) -> HTTPException:
def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]:
if isinstance(exc, ValidationError):
exc = RequestValidationError(exc.raw_errors)
@ -207,7 +208,7 @@ def create_dynamic_passthrough(
def create_dynamic_typed_route(
func: Any, method: str, provider_data_validator: Optional[str]
func: Any, method: str, provider_data_validators: List[str]
):
hints = get_type_hints(func)
response_model = hints.get("return")
@ -223,7 +224,7 @@ def create_dynamic_typed_route(
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
set_request_provider_data(request.headers, provider_data_validator)
set_request_provider_data(request.headers, provider_data_validators)
async def sse_generator(event_gen):
try:
@ -254,7 +255,7 @@ def create_dynamic_typed_route(
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
set_request_provider_data(request.headers, provider_data_validator)
set_request_provider_data(request.headers, provider_data_validators)
try:
return (
@ -415,6 +416,15 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
app = FastAPI()
# Health check is added to enable deploying the docker container image on Kubernetes which require
# a health check that can return 200 for readiness and liveness check
class HealthCheck(BaseModel):
status: str = "OK"
@app.get("/healthcheck", status_code=HTTPStatus.OK, response_model=HealthCheck)
async def healthcheck():
return HealthCheck(status="OK")
impls, specs = asyncio.run(resolve_impls_with_routing(config))
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
@ -454,15 +464,22 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
)
impl_method = getattr(impl, endpoint.name)
validators = []
if isinstance(provider_spec, AutoRoutedProviderSpec):
inner_specs = specs[provider_spec.routing_table_api].inner_specs
for spec in inner_specs:
if spec.provider_data_validator:
validators.append(spec.provider_data_validator)
elif not isinstance(provider_spec, RoutingTableProviderSpec):
if provider_spec.provider_data_validator:
validators.append(provider_spec.provider_data_validator)
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(
impl_method,
endpoint.method,
(
provider_spec.provider_data_validator
if not isinstance(provider_spec, RoutingTableProviderSpec)
else None
),
validators,
)
)