mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-04 10:10:36 +00:00
feat: Implement FastAPI router system (#4191)
# What does this PR do? This commit introduces a new FastAPI router-based system for defining API endpoints, enabling a migration path away from the legacy @webmethod decorator system. The implementation includes router infrastructure, migration of the Batches API as the first example, and updates to server, OpenAPI generation, and inspection systems to support both routing approaches. The router infrastructure consists of a router registry system that allows APIs to register FastAPI router factories, which are then automatically discovered and included in the server application. Standard error responses are centralized in router_utils to ensure consistent OpenAPI specification generation with proper $ref references to component responses. The Batches API has been migrated to demonstrate the new pattern. The protocol definition and models remain in llama_stack_api/batches, maintaining clear separation between API contracts and server implementation. The FastAPI router implementation lives in llama_stack/core/server/routers/batches, following the established pattern where API contracts are defined in llama_stack_api and server routing logic lives in llama_stack/core/server. The server now checks for registered routers before falling back to the legacy webmethod-based route discovery, ensuring backward compatibility during the migration period. The OpenAPI generator has been updated to handle both router-based and webmethod-based routes, correctly extracting metadata from FastAPI route decorators and Pydantic Field descriptions. The inspect endpoint now includes routes from both systems, with proper filtering for deprecated routes and API levels. Response descriptions are now explicitly defined in router decorators, ensuring the generated OpenAPI specification matches the previous format. Error responses use $ref references to component responses (BadRequest400, TooManyRequests429, etc.) as required by the specification. This is neat and will allow us to remove a lot of boiler plate code from our generator once the migration is done. This implementation provides a foundation for incrementally migrating other APIs to the router system while maintaining full backward compatibility with existing webmethod-based APIs. Closes: https://github.com/llamastack/llama-stack/issues/4188 ## Test Plan CI, the server should start, same routes should be visible. ``` curl http://localhost:8321/v1/inspect/routes | jq '.data[] | select(.route | contains("batches"))' ``` Also: ``` uv run pytest tests/integration/batches/ -vv --stack-config=http://localhost:8321 ================================================== test session starts ================================================== platform darwin -- Python 3.12.8, pytest-8.4.2, pluggy-1.6.0 -- /Users/leseb/Documents/AI/llama-stack/.venv/bin/python3 cachedir: .pytest_cache metadata: {'Python': '3.12.8', 'Platform': 'macOS-26.0.1-arm64-arm-64bit', 'Packages': {'pytest': '8.4.2', 'pluggy': '1.6.0'}, 'Plugins': {'anyio': '4.9.0', 'html': '4.1.1', 'socket': '0.7.0', 'asyncio': '1.1.0', 'json-report': '1.5.0', 'timeout': '2.4.0', 'metadata': '3.1.1', 'cov': '6.2.1', 'nbval': '0.11.0'}} rootdir: /Users/leseb/Documents/AI/llama-stack configfile: pyproject.toml plugins: anyio-4.9.0, html-4.1.1, socket-0.7.0, asyncio-1.1.0, json-report-1.5.0, timeout-2.4.0, metadata-3.1.1, cov-6.2.1, nbval-0.11.0 asyncio: mode=Mode.AUTO, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function collected 24 items tests/integration/batches/test_batches.py::TestBatchesIntegration::test_batch_creation_and_retrieval[None] SKIPPED [ 4%] tests/integration/batches/test_batches.py::TestBatchesIntegration::test_batch_listing[None] SKIPPED [ 8%] tests/integration/batches/test_batches.py::TestBatchesIntegration::test_batch_immediate_cancellation[None] SKIPPED [ 12%] tests/integration/batches/test_batches.py::TestBatchesIntegration::test_batch_e2e_chat_completions[None] SKIPPED [ 16%] tests/integration/batches/test_batches.py::TestBatchesIntegration::test_batch_e2e_completions[None] SKIPPED [ 20%] tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_invalid_endpoint[None] SKIPPED [ 25%] tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_cancel_completed[None] SKIPPED [ 29%] tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_missing_required_fields[None] SKIPPED [ 33%] tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_invalid_completion_window[None] SKIPPED [ 37%] tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_streaming_not_supported[None] SKIPPED [ 41%] tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_mixed_streaming_requests[None] SKIPPED [ 45%] tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_endpoint_mismatch[None] SKIPPED [ 50%] tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_missing_required_body_fields[None] SKIPPED [ 54%] tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_invalid_metadata_types[None] SKIPPED [ 58%] tests/integration/batches/test_batches.py::TestBatchesIntegration::test_batch_e2e_embeddings[None] SKIPPED [ 62%] tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_nonexistent_file_id PASSED [ 66%] tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_malformed_jsonl PASSED [ 70%] tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_file_malformed_batch_file[empty] XFAIL [ 75%] tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_file_malformed_batch_file[malformed] XFAIL [ 79%] tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_retrieve_nonexistent PASSED [ 83%] tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_cancel_nonexistent PASSED [ 87%] tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_error_handling_invalid_model PASSED [ 91%] tests/integration/batches/test_batches_idempotency.py::TestBatchesIdempotencyIntegration::test_idempotent_batch_creation_successful PASSED [ 95%] tests/integration/batches/test_batches_idempotency.py::TestBatchesIdempotencyIntegration::test_idempotency_conflict_with_different_params PASSED [100%] ================================================= slowest 10 durations ================================================== 1.01s call tests/integration/batches/test_batches_idempotency.py::TestBatchesIdempotencyIntegration::test_idempotent_batch_creation_successful 0.21s call tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_nonexistent_file_id 0.17s call tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_malformed_jsonl 0.12s call tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_error_handling_invalid_model 0.05s setup tests/integration/batches/test_batches.py::TestBatchesIntegration::test_batch_creation_and_retrieval[None] 0.02s call tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_file_malformed_batch_file[empty] 0.01s call tests/integration/batches/test_batches_idempotency.py::TestBatchesIdempotencyIntegration::test_idempotency_conflict_with_different_params 0.01s call tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_file_malformed_batch_file[malformed] 0.01s call tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_retrieve_nonexistent 0.00s call tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_cancel_nonexistent ======================================= 7 passed, 15 skipped, 2 xfailed in 1.78s ======================================== ``` --------- Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
4237eb4aaa
commit
7f43051a63
22 changed files with 1095 additions and 248 deletions
|
|
@ -37,7 +37,7 @@ paths:
|
|||
description: Default Response
|
||||
tags:
|
||||
- Batches
|
||||
summary: List Batches
|
||||
summary: List all batches for the current user.
|
||||
description: List all batches for the current user.
|
||||
operationId: list_batches_v1_batches_get
|
||||
parameters:
|
||||
|
|
@ -48,14 +48,18 @@ paths:
|
|||
anyOf:
|
||||
- type: string
|
||||
- type: 'null'
|
||||
description: Optional cursor for pagination. Returns batches after this ID.
|
||||
title: After
|
||||
description: Optional cursor for pagination. Returns batches after this ID.
|
||||
- name: limit
|
||||
in: query
|
||||
required: false
|
||||
schema:
|
||||
type: integer
|
||||
description: Maximum number of batches to return. Defaults to 20.
|
||||
default: 20
|
||||
title: Limit
|
||||
description: Maximum number of batches to return. Defaults to 20.
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
|
|
@ -76,9 +80,11 @@ paths:
|
|||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
description: Default Response
|
||||
'409':
|
||||
description: 'Conflict: The idempotency key was previously used with different parameters.'
|
||||
tags:
|
||||
- Batches
|
||||
summary: Create Batch
|
||||
summary: Create a new batch for processing multiple API requests.
|
||||
description: Create a new batch for processing multiple API requests.
|
||||
operationId: create_batch_v1_batches_post
|
||||
requestBody:
|
||||
|
|
@ -97,20 +103,20 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/Batch'
|
||||
'400':
|
||||
description: Bad Request
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
description: Bad Request
|
||||
'429':
|
||||
description: Too Many Requests
|
||||
$ref: '#/components/responses/TooManyRequests429'
|
||||
description: Too Many Requests
|
||||
'500':
|
||||
description: Internal Server Error
|
||||
$ref: '#/components/responses/InternalServerError500'
|
||||
description: Internal Server Error
|
||||
default:
|
||||
description: Default Response
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
description: Default Response
|
||||
tags:
|
||||
- Batches
|
||||
summary: Retrieve Batch
|
||||
summary: Retrieve information about a specific batch.
|
||||
description: Retrieve information about a specific batch.
|
||||
operationId: retrieve_batch_v1_batches__batch_id__get
|
||||
parameters:
|
||||
|
|
@ -119,7 +125,9 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: 'Path parameter: batch_id'
|
||||
description: The ID of the batch to retrieve.
|
||||
title: Batch Id
|
||||
description: The ID of the batch to retrieve.
|
||||
/v1/batches/{batch_id}/cancel:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -130,20 +138,20 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/Batch'
|
||||
'400':
|
||||
description: Bad Request
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
description: Bad Request
|
||||
'429':
|
||||
description: Too Many Requests
|
||||
$ref: '#/components/responses/TooManyRequests429'
|
||||
description: Too Many Requests
|
||||
'500':
|
||||
description: Internal Server Error
|
||||
$ref: '#/components/responses/InternalServerError500'
|
||||
description: Internal Server Error
|
||||
default:
|
||||
description: Default Response
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
description: Default Response
|
||||
tags:
|
||||
- Batches
|
||||
summary: Cancel Batch
|
||||
summary: Cancel a batch that is in progress.
|
||||
description: Cancel a batch that is in progress.
|
||||
operationId: cancel_batch_v1_batches__batch_id__cancel_post
|
||||
parameters:
|
||||
|
|
@ -152,7 +160,9 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: 'Path parameter: batch_id'
|
||||
description: The ID of the batch to cancel.
|
||||
title: Batch Id
|
||||
description: The ID of the batch to cancel.
|
||||
/v1/chat/completions:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -3956,29 +3966,35 @@ components:
|
|||
input_file_id:
|
||||
type: string
|
||||
title: Input File Id
|
||||
description: The ID of an uploaded file containing requests for the batch.
|
||||
endpoint:
|
||||
type: string
|
||||
title: Endpoint
|
||||
description: The endpoint to be used for all requests in the batch.
|
||||
completion_window:
|
||||
type: string
|
||||
const: 24h
|
||||
title: Completion Window
|
||||
description: The time window within which the batch should be processed.
|
||||
metadata:
|
||||
anyOf:
|
||||
- additionalProperties:
|
||||
type: string
|
||||
type: object
|
||||
- type: 'null'
|
||||
description: Optional metadata for the batch.
|
||||
idempotency_key:
|
||||
anyOf:
|
||||
- type: string
|
||||
- type: 'null'
|
||||
description: Optional idempotency key. When provided, enables idempotent behavior.
|
||||
type: object
|
||||
required:
|
||||
- input_file_id
|
||||
- endpoint
|
||||
- completion_window
|
||||
title: CreateBatchRequest
|
||||
description: Request model for creating a batch.
|
||||
Batch:
|
||||
properties:
|
||||
id:
|
||||
|
|
@ -12563,6 +12579,44 @@ components:
|
|||
- query
|
||||
title: VectorStoreSearchRequest
|
||||
type: object
|
||||
ListBatchesRequest:
|
||||
description: Request model for listing batches.
|
||||
properties:
|
||||
after:
|
||||
anyOf:
|
||||
- type: string
|
||||
- type: 'null'
|
||||
description: Optional cursor for pagination. Returns batches after this ID.
|
||||
nullable: true
|
||||
limit:
|
||||
default: 20
|
||||
description: Maximum number of batches to return. Defaults to 20.
|
||||
title: Limit
|
||||
type: integer
|
||||
title: ListBatchesRequest
|
||||
type: object
|
||||
RetrieveBatchRequest:
|
||||
description: Request model for retrieving a batch.
|
||||
properties:
|
||||
batch_id:
|
||||
description: The ID of the batch to retrieve.
|
||||
title: Batch Id
|
||||
type: string
|
||||
required:
|
||||
- batch_id
|
||||
title: RetrieveBatchRequest
|
||||
type: object
|
||||
CancelBatchRequest:
|
||||
description: Request model for canceling a batch.
|
||||
properties:
|
||||
batch_id:
|
||||
description: The ID of the batch to cancel.
|
||||
title: Batch Id
|
||||
type: string
|
||||
required:
|
||||
- batch_id
|
||||
title: CancelBatchRequest
|
||||
type: object
|
||||
DialogType:
|
||||
description: Parameter type for dialog data with semantic output labels.
|
||||
properties:
|
||||
|
|
|
|||
44
docs/static/deprecated-llama-stack-spec.yaml
vendored
44
docs/static/deprecated-llama-stack-spec.yaml
vendored
|
|
@ -950,29 +950,35 @@ components:
|
|||
input_file_id:
|
||||
type: string
|
||||
title: Input File Id
|
||||
description: The ID of an uploaded file containing requests for the batch.
|
||||
endpoint:
|
||||
type: string
|
||||
title: Endpoint
|
||||
description: The endpoint to be used for all requests in the batch.
|
||||
completion_window:
|
||||
type: string
|
||||
const: 24h
|
||||
title: Completion Window
|
||||
description: The time window within which the batch should be processed.
|
||||
metadata:
|
||||
anyOf:
|
||||
- additionalProperties:
|
||||
type: string
|
||||
type: object
|
||||
- type: 'null'
|
||||
description: Optional metadata for the batch.
|
||||
idempotency_key:
|
||||
anyOf:
|
||||
- type: string
|
||||
- type: 'null'
|
||||
description: Optional idempotency key. When provided, enables idempotent behavior.
|
||||
type: object
|
||||
required:
|
||||
- input_file_id
|
||||
- endpoint
|
||||
- completion_window
|
||||
title: CreateBatchRequest
|
||||
description: Request model for creating a batch.
|
||||
Batch:
|
||||
properties:
|
||||
id:
|
||||
|
|
@ -9557,6 +9563,44 @@ components:
|
|||
- query
|
||||
title: VectorStoreSearchRequest
|
||||
type: object
|
||||
ListBatchesRequest:
|
||||
description: Request model for listing batches.
|
||||
properties:
|
||||
after:
|
||||
anyOf:
|
||||
- type: string
|
||||
- type: 'null'
|
||||
description: Optional cursor for pagination. Returns batches after this ID.
|
||||
nullable: true
|
||||
limit:
|
||||
default: 20
|
||||
description: Maximum number of batches to return. Defaults to 20.
|
||||
title: Limit
|
||||
type: integer
|
||||
title: ListBatchesRequest
|
||||
type: object
|
||||
RetrieveBatchRequest:
|
||||
description: Request model for retrieving a batch.
|
||||
properties:
|
||||
batch_id:
|
||||
description: The ID of the batch to retrieve.
|
||||
title: Batch Id
|
||||
type: string
|
||||
required:
|
||||
- batch_id
|
||||
title: RetrieveBatchRequest
|
||||
type: object
|
||||
CancelBatchRequest:
|
||||
description: Request model for canceling a batch.
|
||||
properties:
|
||||
batch_id:
|
||||
description: The ID of the batch to cancel.
|
||||
title: Batch Id
|
||||
type: string
|
||||
required:
|
||||
- batch_id
|
||||
title: CancelBatchRequest
|
||||
type: object
|
||||
DialogType:
|
||||
description: Parameter type for dialog data with semantic output labels.
|
||||
properties:
|
||||
|
|
|
|||
72
docs/static/experimental-llama-stack-spec.yaml
vendored
72
docs/static/experimental-llama-stack-spec.yaml
vendored
|
|
@ -688,6 +688,40 @@ components:
|
|||
- data
|
||||
title: ListBatchesResponse
|
||||
description: Response containing a list of batch objects.
|
||||
CreateBatchRequest:
|
||||
properties:
|
||||
input_file_id:
|
||||
type: string
|
||||
title: Input File Id
|
||||
description: The ID of an uploaded file containing requests for the batch.
|
||||
endpoint:
|
||||
type: string
|
||||
title: Endpoint
|
||||
description: The endpoint to be used for all requests in the batch.
|
||||
completion_window:
|
||||
type: string
|
||||
const: 24h
|
||||
title: Completion Window
|
||||
description: The time window within which the batch should be processed.
|
||||
metadata:
|
||||
anyOf:
|
||||
- additionalProperties:
|
||||
type: string
|
||||
type: object
|
||||
- type: 'null'
|
||||
description: Optional metadata for the batch.
|
||||
idempotency_key:
|
||||
anyOf:
|
||||
- type: string
|
||||
- type: 'null'
|
||||
description: Optional idempotency key. When provided, enables idempotent behavior.
|
||||
type: object
|
||||
required:
|
||||
- input_file_id
|
||||
- endpoint
|
||||
- completion_window
|
||||
title: CreateBatchRequest
|
||||
description: Request model for creating a batch.
|
||||
Batch:
|
||||
properties:
|
||||
id:
|
||||
|
|
@ -8323,6 +8357,44 @@ components:
|
|||
- query
|
||||
title: VectorStoreSearchRequest
|
||||
type: object
|
||||
ListBatchesRequest:
|
||||
description: Request model for listing batches.
|
||||
properties:
|
||||
after:
|
||||
anyOf:
|
||||
- type: string
|
||||
- type: 'null'
|
||||
description: Optional cursor for pagination. Returns batches after this ID.
|
||||
nullable: true
|
||||
limit:
|
||||
default: 20
|
||||
description: Maximum number of batches to return. Defaults to 20.
|
||||
title: Limit
|
||||
type: integer
|
||||
title: ListBatchesRequest
|
||||
type: object
|
||||
RetrieveBatchRequest:
|
||||
description: Request model for retrieving a batch.
|
||||
properties:
|
||||
batch_id:
|
||||
description: The ID of the batch to retrieve.
|
||||
title: Batch Id
|
||||
type: string
|
||||
required:
|
||||
- batch_id
|
||||
title: RetrieveBatchRequest
|
||||
type: object
|
||||
CancelBatchRequest:
|
||||
description: Request model for canceling a batch.
|
||||
properties:
|
||||
batch_id:
|
||||
description: The ID of the batch to cancel.
|
||||
title: Batch Id
|
||||
type: string
|
||||
required:
|
||||
- batch_id
|
||||
title: CancelBatchRequest
|
||||
type: object
|
||||
DialogType:
|
||||
description: Parameter type for dialog data with semantic output labels.
|
||||
properties:
|
||||
|
|
|
|||
82
docs/static/llama-stack-spec.yaml
vendored
82
docs/static/llama-stack-spec.yaml
vendored
|
|
@ -35,7 +35,7 @@ paths:
|
|||
description: Default Response
|
||||
tags:
|
||||
- Batches
|
||||
summary: List Batches
|
||||
summary: List all batches for the current user.
|
||||
description: List all batches for the current user.
|
||||
operationId: list_batches_v1_batches_get
|
||||
parameters:
|
||||
|
|
@ -46,14 +46,18 @@ paths:
|
|||
anyOf:
|
||||
- type: string
|
||||
- type: 'null'
|
||||
description: Optional cursor for pagination. Returns batches after this ID.
|
||||
title: After
|
||||
description: Optional cursor for pagination. Returns batches after this ID.
|
||||
- name: limit
|
||||
in: query
|
||||
required: false
|
||||
schema:
|
||||
type: integer
|
||||
description: Maximum number of batches to return. Defaults to 20.
|
||||
default: 20
|
||||
title: Limit
|
||||
description: Maximum number of batches to return. Defaults to 20.
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
|
|
@ -74,9 +78,11 @@ paths:
|
|||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
description: Default Response
|
||||
'409':
|
||||
description: 'Conflict: The idempotency key was previously used with different parameters.'
|
||||
tags:
|
||||
- Batches
|
||||
summary: Create Batch
|
||||
summary: Create a new batch for processing multiple API requests.
|
||||
description: Create a new batch for processing multiple API requests.
|
||||
operationId: create_batch_v1_batches_post
|
||||
requestBody:
|
||||
|
|
@ -95,20 +101,20 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/Batch'
|
||||
'400':
|
||||
description: Bad Request
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
description: Bad Request
|
||||
'429':
|
||||
description: Too Many Requests
|
||||
$ref: '#/components/responses/TooManyRequests429'
|
||||
description: Too Many Requests
|
||||
'500':
|
||||
description: Internal Server Error
|
||||
$ref: '#/components/responses/InternalServerError500'
|
||||
description: Internal Server Error
|
||||
default:
|
||||
description: Default Response
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
description: Default Response
|
||||
tags:
|
||||
- Batches
|
||||
summary: Retrieve Batch
|
||||
summary: Retrieve information about a specific batch.
|
||||
description: Retrieve information about a specific batch.
|
||||
operationId: retrieve_batch_v1_batches__batch_id__get
|
||||
parameters:
|
||||
|
|
@ -117,7 +123,9 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: 'Path parameter: batch_id'
|
||||
description: The ID of the batch to retrieve.
|
||||
title: Batch Id
|
||||
description: The ID of the batch to retrieve.
|
||||
/v1/batches/{batch_id}/cancel:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -128,20 +136,20 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/Batch'
|
||||
'400':
|
||||
description: Bad Request
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
description: Bad Request
|
||||
'429':
|
||||
description: Too Many Requests
|
||||
$ref: '#/components/responses/TooManyRequests429'
|
||||
description: Too Many Requests
|
||||
'500':
|
||||
description: Internal Server Error
|
||||
$ref: '#/components/responses/InternalServerError500'
|
||||
description: Internal Server Error
|
||||
default:
|
||||
description: Default Response
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
description: Default Response
|
||||
tags:
|
||||
- Batches
|
||||
summary: Cancel Batch
|
||||
summary: Cancel a batch that is in progress.
|
||||
description: Cancel a batch that is in progress.
|
||||
operationId: cancel_batch_v1_batches__batch_id__cancel_post
|
||||
parameters:
|
||||
|
|
@ -150,7 +158,9 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: 'Path parameter: batch_id'
|
||||
description: The ID of the batch to cancel.
|
||||
title: Batch Id
|
||||
description: The ID of the batch to cancel.
|
||||
/v1/chat/completions:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -2761,29 +2771,35 @@ components:
|
|||
input_file_id:
|
||||
type: string
|
||||
title: Input File Id
|
||||
description: The ID of an uploaded file containing requests for the batch.
|
||||
endpoint:
|
||||
type: string
|
||||
title: Endpoint
|
||||
description: The endpoint to be used for all requests in the batch.
|
||||
completion_window:
|
||||
type: string
|
||||
const: 24h
|
||||
title: Completion Window
|
||||
description: The time window within which the batch should be processed.
|
||||
metadata:
|
||||
anyOf:
|
||||
- additionalProperties:
|
||||
type: string
|
||||
type: object
|
||||
- type: 'null'
|
||||
description: Optional metadata for the batch.
|
||||
idempotency_key:
|
||||
anyOf:
|
||||
- type: string
|
||||
- type: 'null'
|
||||
description: Optional idempotency key. When provided, enables idempotent behavior.
|
||||
type: object
|
||||
required:
|
||||
- input_file_id
|
||||
- endpoint
|
||||
- completion_window
|
||||
title: CreateBatchRequest
|
||||
description: Request model for creating a batch.
|
||||
Batch:
|
||||
properties:
|
||||
id:
|
||||
|
|
@ -10999,6 +11015,44 @@ components:
|
|||
- query
|
||||
title: VectorStoreSearchRequest
|
||||
type: object
|
||||
ListBatchesRequest:
|
||||
description: Request model for listing batches.
|
||||
properties:
|
||||
after:
|
||||
anyOf:
|
||||
- type: string
|
||||
- type: 'null'
|
||||
description: Optional cursor for pagination. Returns batches after this ID.
|
||||
nullable: true
|
||||
limit:
|
||||
default: 20
|
||||
description: Maximum number of batches to return. Defaults to 20.
|
||||
title: Limit
|
||||
type: integer
|
||||
title: ListBatchesRequest
|
||||
type: object
|
||||
RetrieveBatchRequest:
|
||||
description: Request model for retrieving a batch.
|
||||
properties:
|
||||
batch_id:
|
||||
description: The ID of the batch to retrieve.
|
||||
title: Batch Id
|
||||
type: string
|
||||
required:
|
||||
- batch_id
|
||||
title: RetrieveBatchRequest
|
||||
type: object
|
||||
CancelBatchRequest:
|
||||
description: Request model for canceling a batch.
|
||||
properties:
|
||||
batch_id:
|
||||
description: The ID of the batch to cancel.
|
||||
title: Batch Id
|
||||
type: string
|
||||
required:
|
||||
- batch_id
|
||||
title: CancelBatchRequest
|
||||
type: object
|
||||
DialogType:
|
||||
description: Parameter type for dialog data with semantic output labels.
|
||||
properties:
|
||||
|
|
|
|||
82
docs/static/stainless-llama-stack-spec.yaml
vendored
82
docs/static/stainless-llama-stack-spec.yaml
vendored
|
|
@ -37,7 +37,7 @@ paths:
|
|||
description: Default Response
|
||||
tags:
|
||||
- Batches
|
||||
summary: List Batches
|
||||
summary: List all batches for the current user.
|
||||
description: List all batches for the current user.
|
||||
operationId: list_batches_v1_batches_get
|
||||
parameters:
|
||||
|
|
@ -48,14 +48,18 @@ paths:
|
|||
anyOf:
|
||||
- type: string
|
||||
- type: 'null'
|
||||
description: Optional cursor for pagination. Returns batches after this ID.
|
||||
title: After
|
||||
description: Optional cursor for pagination. Returns batches after this ID.
|
||||
- name: limit
|
||||
in: query
|
||||
required: false
|
||||
schema:
|
||||
type: integer
|
||||
description: Maximum number of batches to return. Defaults to 20.
|
||||
default: 20
|
||||
title: Limit
|
||||
description: Maximum number of batches to return. Defaults to 20.
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
|
|
@ -76,9 +80,11 @@ paths:
|
|||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
description: Default Response
|
||||
'409':
|
||||
description: 'Conflict: The idempotency key was previously used with different parameters.'
|
||||
tags:
|
||||
- Batches
|
||||
summary: Create Batch
|
||||
summary: Create a new batch for processing multiple API requests.
|
||||
description: Create a new batch for processing multiple API requests.
|
||||
operationId: create_batch_v1_batches_post
|
||||
requestBody:
|
||||
|
|
@ -97,20 +103,20 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/Batch'
|
||||
'400':
|
||||
description: Bad Request
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
description: Bad Request
|
||||
'429':
|
||||
description: Too Many Requests
|
||||
$ref: '#/components/responses/TooManyRequests429'
|
||||
description: Too Many Requests
|
||||
'500':
|
||||
description: Internal Server Error
|
||||
$ref: '#/components/responses/InternalServerError500'
|
||||
description: Internal Server Error
|
||||
default:
|
||||
description: Default Response
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
description: Default Response
|
||||
tags:
|
||||
- Batches
|
||||
summary: Retrieve Batch
|
||||
summary: Retrieve information about a specific batch.
|
||||
description: Retrieve information about a specific batch.
|
||||
operationId: retrieve_batch_v1_batches__batch_id__get
|
||||
parameters:
|
||||
|
|
@ -119,7 +125,9 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: 'Path parameter: batch_id'
|
||||
description: The ID of the batch to retrieve.
|
||||
title: Batch Id
|
||||
description: The ID of the batch to retrieve.
|
||||
/v1/batches/{batch_id}/cancel:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -130,20 +138,20 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/Batch'
|
||||
'400':
|
||||
description: Bad Request
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
description: Bad Request
|
||||
'429':
|
||||
description: Too Many Requests
|
||||
$ref: '#/components/responses/TooManyRequests429'
|
||||
description: Too Many Requests
|
||||
'500':
|
||||
description: Internal Server Error
|
||||
$ref: '#/components/responses/InternalServerError500'
|
||||
description: Internal Server Error
|
||||
default:
|
||||
description: Default Response
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
description: Default Response
|
||||
tags:
|
||||
- Batches
|
||||
summary: Cancel Batch
|
||||
summary: Cancel a batch that is in progress.
|
||||
description: Cancel a batch that is in progress.
|
||||
operationId: cancel_batch_v1_batches__batch_id__cancel_post
|
||||
parameters:
|
||||
|
|
@ -152,7 +160,9 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: 'Path parameter: batch_id'
|
||||
description: The ID of the batch to cancel.
|
||||
title: Batch Id
|
||||
description: The ID of the batch to cancel.
|
||||
/v1/chat/completions:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -3956,29 +3966,35 @@ components:
|
|||
input_file_id:
|
||||
type: string
|
||||
title: Input File Id
|
||||
description: The ID of an uploaded file containing requests for the batch.
|
||||
endpoint:
|
||||
type: string
|
||||
title: Endpoint
|
||||
description: The endpoint to be used for all requests in the batch.
|
||||
completion_window:
|
||||
type: string
|
||||
const: 24h
|
||||
title: Completion Window
|
||||
description: The time window within which the batch should be processed.
|
||||
metadata:
|
||||
anyOf:
|
||||
- additionalProperties:
|
||||
type: string
|
||||
type: object
|
||||
- type: 'null'
|
||||
description: Optional metadata for the batch.
|
||||
idempotency_key:
|
||||
anyOf:
|
||||
- type: string
|
||||
- type: 'null'
|
||||
description: Optional idempotency key. When provided, enables idempotent behavior.
|
||||
type: object
|
||||
required:
|
||||
- input_file_id
|
||||
- endpoint
|
||||
- completion_window
|
||||
title: CreateBatchRequest
|
||||
description: Request model for creating a batch.
|
||||
Batch:
|
||||
properties:
|
||||
id:
|
||||
|
|
@ -12563,6 +12579,44 @@ components:
|
|||
- query
|
||||
title: VectorStoreSearchRequest
|
||||
type: object
|
||||
ListBatchesRequest:
|
||||
description: Request model for listing batches.
|
||||
properties:
|
||||
after:
|
||||
anyOf:
|
||||
- type: string
|
||||
- type: 'null'
|
||||
description: Optional cursor for pagination. Returns batches after this ID.
|
||||
nullable: true
|
||||
limit:
|
||||
default: 20
|
||||
description: Maximum number of batches to return. Defaults to 20.
|
||||
title: Limit
|
||||
type: integer
|
||||
title: ListBatchesRequest
|
||||
type: object
|
||||
RetrieveBatchRequest:
|
||||
description: Request model for retrieving a batch.
|
||||
properties:
|
||||
batch_id:
|
||||
description: The ID of the batch to retrieve.
|
||||
title: Batch Id
|
||||
type: string
|
||||
required:
|
||||
- batch_id
|
||||
title: RetrieveBatchRequest
|
||||
type: object
|
||||
CancelBatchRequest:
|
||||
description: Request model for canceling a batch.
|
||||
properties:
|
||||
batch_id:
|
||||
description: The ID of the batch to cancel.
|
||||
title: Batch Id
|
||||
type: string
|
||||
required:
|
||||
- batch_id
|
||||
title: CancelBatchRequest
|
||||
type: object
|
||||
DialogType:
|
||||
description: Parameter type for dialog data with semantic output labels.
|
||||
properties:
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from typing import Any
|
|||
from fastapi import FastAPI
|
||||
|
||||
from llama_stack.core.resolver import api_protocol_map
|
||||
from llama_stack.core.server.fastapi_router_registry import build_fastapi_router, has_router
|
||||
from llama_stack_api import Api
|
||||
|
||||
from .state import _protocol_methods_cache
|
||||
|
|
@ -64,7 +65,8 @@ def _get_protocol_method(api: Api, method_name: str) -> Any | None:
|
|||
def create_llama_stack_app() -> FastAPI:
|
||||
"""
|
||||
Create a FastAPI app that represents the Llama Stack API.
|
||||
This uses the existing route discovery system to automatically find all routes.
|
||||
This uses both router-based routes (for migrated APIs) and the existing
|
||||
route discovery system for legacy webmethod-based routes.
|
||||
"""
|
||||
app = FastAPI(
|
||||
title="Llama Stack API",
|
||||
|
|
@ -75,15 +77,27 @@ def create_llama_stack_app() -> FastAPI:
|
|||
],
|
||||
)
|
||||
|
||||
# Get all API routes
|
||||
# Include routers for APIs that have them
|
||||
protocols = api_protocol_map()
|
||||
for api in protocols.keys():
|
||||
# For OpenAPI generation, we don't need a real implementation
|
||||
if not has_router(api):
|
||||
continue
|
||||
app.include_router(build_fastapi_router(api, None))
|
||||
|
||||
# Get all API routes (for legacy webmethod-based routes)
|
||||
from llama_stack.core.server.routes import get_all_api_routes
|
||||
|
||||
api_routes = get_all_api_routes()
|
||||
|
||||
# Create FastAPI routes from the discovered routes
|
||||
# Create FastAPI routes from the discovered routes (skip APIs that have routers)
|
||||
from . import endpoints
|
||||
|
||||
for api, routes in api_routes.items():
|
||||
# Skip APIs that have routers - they're already included above
|
||||
if has_router(api):
|
||||
continue
|
||||
|
||||
for route, webmethod in routes:
|
||||
# Convert the route to a FastAPI endpoint
|
||||
endpoints._create_fastapi_endpoint(app, route, webmethod, api)
|
||||
|
|
|
|||
|
|
@ -10,8 +10,14 @@ from pydantic import BaseModel
|
|||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.core.server.fastapi_router_registry import (
|
||||
_ROUTER_FACTORIES,
|
||||
build_fastapi_router,
|
||||
get_router_routes,
|
||||
)
|
||||
from llama_stack.core.server.routes import get_all_api_routes
|
||||
from llama_stack_api import (
|
||||
Api,
|
||||
HealthInfo,
|
||||
HealthStatus,
|
||||
Inspect,
|
||||
|
|
@ -43,6 +49,7 @@ class DistributionInspectImpl(Inspect):
|
|||
run_config: StackRunConfig = self.config.run_config
|
||||
|
||||
# Helper function to determine if a route should be included based on api_filter
|
||||
# TODO: remove this once we've migrated all APIs to FastAPI routers
|
||||
def should_include_route(webmethod) -> bool:
|
||||
if api_filter is None:
|
||||
# Default: only non-deprecated APIs
|
||||
|
|
@ -54,10 +61,62 @@ class DistributionInspectImpl(Inspect):
|
|||
# Filter by API level (non-deprecated routes only)
|
||||
return not webmethod.deprecated and webmethod.level == api_filter
|
||||
|
||||
# Helper function to get provider types for an API
|
||||
def _get_provider_types(api: Api) -> list[str]:
|
||||
if api.value in ["providers", "inspect"]:
|
||||
return [] # These APIs don't have "real" providers they're internal to the stack
|
||||
providers = run_config.providers.get(api.value, [])
|
||||
return [p.provider_type for p in providers] if providers else []
|
||||
|
||||
# Helper function to determine if a router route should be included based on api_filter
|
||||
def _should_include_router_route(route, router_prefix: str | None) -> bool:
|
||||
"""Check if a router-based route should be included based on api_filter."""
|
||||
# Check deprecated status
|
||||
route_deprecated = getattr(route, "deprecated", False) or False
|
||||
|
||||
if api_filter is None:
|
||||
# Default: only non-deprecated routes
|
||||
return not route_deprecated
|
||||
elif api_filter == "deprecated":
|
||||
# Special filter: show deprecated routes regardless of their actual level
|
||||
return route_deprecated
|
||||
else:
|
||||
# Filter by API level (non-deprecated routes only)
|
||||
# Extract level from router prefix (e.g., "/v1" -> "v1")
|
||||
if router_prefix:
|
||||
prefix_level = router_prefix.lstrip("/")
|
||||
return not route_deprecated and prefix_level == api_filter
|
||||
return not route_deprecated
|
||||
|
||||
ret = []
|
||||
external_apis = load_external_apis(run_config)
|
||||
all_endpoints = get_all_api_routes(external_apis)
|
||||
|
||||
# Process routes from APIs with FastAPI routers
|
||||
for api_name in _ROUTER_FACTORIES.keys():
|
||||
api = Api(api_name)
|
||||
router = build_fastapi_router(api, None) # we don't need the impl here, just the routes
|
||||
if router:
|
||||
router_routes = get_router_routes(router)
|
||||
for route in router_routes:
|
||||
if _should_include_router_route(route, router.prefix):
|
||||
if route.methods is not None:
|
||||
available_methods = [m for m in route.methods if m != "HEAD"]
|
||||
if available_methods:
|
||||
ret.append(
|
||||
RouteInfo(
|
||||
route=route.path,
|
||||
method=available_methods[0],
|
||||
provider_types=_get_provider_types(api),
|
||||
)
|
||||
)
|
||||
|
||||
# Process routes from legacy webmethod-based APIs
|
||||
for api, endpoints in all_endpoints.items():
|
||||
# Skip APIs that have routers (already processed above)
|
||||
if api.value in _ROUTER_FACTORIES:
|
||||
continue
|
||||
|
||||
# Always include provider and inspect APIs, filter others based on run config
|
||||
if api.value in ["providers", "inspect"]:
|
||||
ret.extend(
|
||||
|
|
|
|||
97
src/llama_stack/core/server/fastapi_router_registry.py
Normal file
97
src/llama_stack/core/server/fastapi_router_registry.py
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Router utilities for FastAPI routers.
|
||||
|
||||
This module provides utilities to create FastAPI routers from API packages.
|
||||
APIs with routers are explicitly listed here.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi.routing import APIRoute
|
||||
from starlette.routing import Route
|
||||
|
||||
from llama_stack_api import batches
|
||||
|
||||
# Router factories for APIs that have FastAPI routers
|
||||
# Add new APIs here as they are migrated to the router system
|
||||
from llama_stack_api.datatypes import Api
|
||||
|
||||
_ROUTER_FACTORIES: dict[str, Callable[[Any], APIRouter]] = {
|
||||
"batches": batches.fastapi_routes.create_router,
|
||||
}
|
||||
|
||||
|
||||
def has_router(api: "Api") -> bool:
|
||||
"""Check if an API has a router factory.
|
||||
|
||||
Args:
|
||||
api: The API enum value
|
||||
|
||||
Returns:
|
||||
True if the API has a router factory, False otherwise
|
||||
"""
|
||||
return api.value in _ROUTER_FACTORIES
|
||||
|
||||
|
||||
def build_fastapi_router(api: "Api", impl: Any) -> APIRouter | None:
|
||||
"""Build a router for an API by combining its router factory with the implementation.
|
||||
|
||||
Args:
|
||||
api: The API enum value
|
||||
impl: The implementation instance for the API
|
||||
|
||||
Returns:
|
||||
APIRouter if the API has a router factory, None otherwise
|
||||
"""
|
||||
router_factory = _ROUTER_FACTORIES.get(api.value)
|
||||
if router_factory is None:
|
||||
return None
|
||||
|
||||
# cast is safe here: all router factories in API packages are required to return APIRouter.
|
||||
# If a router factory returns the wrong type, it will fail at runtime when
|
||||
# app.include_router(router) is called
|
||||
return cast(APIRouter, router_factory(impl))
|
||||
|
||||
|
||||
def get_router_routes(router: APIRouter) -> list[Route]:
|
||||
"""Extract routes from a FastAPI router.
|
||||
|
||||
Args:
|
||||
router: The FastAPI router to extract routes from
|
||||
|
||||
Returns:
|
||||
List of Route objects from the router
|
||||
"""
|
||||
routes = []
|
||||
|
||||
for route in router.routes:
|
||||
# FastAPI routers use APIRoute objects, which have path and methods attributes
|
||||
if isinstance(route, APIRoute):
|
||||
# Combine router prefix with route path
|
||||
routes.append(
|
||||
Route(
|
||||
path=route.path,
|
||||
methods=route.methods,
|
||||
name=route.name,
|
||||
endpoint=route.endpoint,
|
||||
)
|
||||
)
|
||||
elif isinstance(route, Route):
|
||||
# Fallback for regular Starlette Route objects
|
||||
routes.append(
|
||||
Route(
|
||||
path=route.path,
|
||||
methods=route.methods,
|
||||
name=route.name,
|
||||
endpoint=route.endpoint,
|
||||
)
|
||||
)
|
||||
|
||||
return routes
|
||||
|
|
@ -26,6 +26,18 @@ RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
|
|||
def get_all_api_routes(
|
||||
external_apis: dict[Api, ExternalApiSpec] | None = None,
|
||||
) -> dict[Api, list[tuple[Route, WebMethod]]]:
|
||||
"""Get all API routes from webmethod-based protocols.
|
||||
|
||||
This function only returns routes from APIs that use the legacy @webmethod
|
||||
decorator system. For APIs that have been migrated to FastAPI routers,
|
||||
use the router registry (fastapi_router_registry.has_router() and fastapi_router_registry.build_fastapi_router()).
|
||||
|
||||
Args:
|
||||
external_apis: Optional dictionary of external API specifications
|
||||
|
||||
Returns:
|
||||
Dictionary mapping API to list of (Route, WebMethod) tuples
|
||||
"""
|
||||
apis = {}
|
||||
|
||||
protocols = api_protocol_map(external_apis)
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ from llama_stack.core.request_headers import (
|
|||
request_provider_data_context,
|
||||
user_from_scope,
|
||||
)
|
||||
from llama_stack.core.server.fastapi_router_registry import build_fastapi_router
|
||||
from llama_stack.core.server.routes import get_all_api_routes
|
||||
from llama_stack.core.stack import (
|
||||
Stack,
|
||||
|
|
@ -84,7 +85,7 @@ def create_sse_event(data: Any) -> str:
|
|||
|
||||
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
traceback.print_exception(exc)
|
||||
traceback.print_exception(type(exc), exc, exc.__traceback__)
|
||||
http_exc = translate_exception(exc)
|
||||
|
||||
return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}})
|
||||
|
|
@ -454,15 +455,22 @@ def create_app() -> StackApp:
|
|||
apis_to_serve.add("providers")
|
||||
apis_to_serve.add("prompts")
|
||||
apis_to_serve.add("conversations")
|
||||
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
||||
routes = all_routes[api]
|
||||
try:
|
||||
# Try to discover and use a router factory from the API package
|
||||
impl = impls[api]
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Could not find provider implementation for {api} API") from e
|
||||
router = build_fastapi_router(api, impl)
|
||||
if router:
|
||||
app.include_router(router)
|
||||
logger.debug(f"Registered FastAPIrouter for {api} API")
|
||||
continue
|
||||
|
||||
# Fall back to old webmethod-based route discovery until the migration is complete
|
||||
impl = impls[api]
|
||||
|
||||
routes = all_routes[api]
|
||||
for route, _ in routes:
|
||||
if not hasattr(impl, route.name):
|
||||
# ideally this should be a typing violation already
|
||||
|
|
@ -488,7 +496,15 @@ def create_app() -> StackApp:
|
|||
|
||||
logger.debug(f"serving APIs: {apis_to_serve}")
|
||||
|
||||
# Register specific exception handlers before the generic Exception handler
|
||||
# This prevents the re-raising behavior that causes connection resets
|
||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(ConflictError)(global_exception_handler)
|
||||
app.exception_handler(ResourceNotFoundError)(global_exception_handler)
|
||||
app.exception_handler(AuthenticationRequiredError)(global_exception_handler)
|
||||
app.exception_handler(AccessDeniedError)(global_exception_handler)
|
||||
app.exception_handler(BadRequestError)(global_exception_handler)
|
||||
# Generic Exception handler should be last
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
|
||||
return app
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import json
|
|||
import time
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Any, Literal
|
||||
from typing import Any
|
||||
|
||||
from openai.types.batch import BatchError, Errors
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -38,6 +38,12 @@ from llama_stack_api import (
|
|||
OpenAIUserMessageParam,
|
||||
ResourceNotFoundError,
|
||||
)
|
||||
from llama_stack_api.batches.models import (
|
||||
CancelBatchRequest,
|
||||
CreateBatchRequest,
|
||||
ListBatchesRequest,
|
||||
RetrieveBatchRequest,
|
||||
)
|
||||
|
||||
from .config import ReferenceBatchesImplConfig
|
||||
|
||||
|
|
@ -140,11 +146,7 @@ class ReferenceBatchesImpl(Batches):
|
|||
# TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions
|
||||
async def create_batch(
|
||||
self,
|
||||
input_file_id: str,
|
||||
endpoint: str,
|
||||
completion_window: Literal["24h"],
|
||||
metadata: dict[str, str] | None = None,
|
||||
idempotency_key: str | None = None,
|
||||
request: CreateBatchRequest,
|
||||
) -> BatchObject:
|
||||
"""
|
||||
Create a new batch for processing multiple API requests.
|
||||
|
|
@ -185,14 +187,14 @@ class ReferenceBatchesImpl(Batches):
|
|||
|
||||
# TODO: set expiration time for garbage collection
|
||||
|
||||
if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
|
||||
if request.endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
|
||||
raise ValueError(
|
||||
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
|
||||
f"Invalid endpoint: {request.endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
|
||||
)
|
||||
|
||||
if completion_window != "24h":
|
||||
if request.completion_window != "24h":
|
||||
raise ValueError(
|
||||
f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
|
||||
f"Invalid completion_window: {request.completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
|
||||
)
|
||||
|
||||
batch_id = f"batch_{uuid.uuid4().hex[:16]}"
|
||||
|
|
@ -200,22 +202,22 @@ class ReferenceBatchesImpl(Batches):
|
|||
# For idempotent requests, use the idempotency key for the batch ID
|
||||
# This ensures the same key always maps to the same batch ID,
|
||||
# allowing us to detect parameter conflicts
|
||||
if idempotency_key is not None:
|
||||
hash_input = idempotency_key.encode("utf-8")
|
||||
if request.idempotency_key is not None:
|
||||
hash_input = request.idempotency_key.encode("utf-8")
|
||||
hash_digest = hashlib.sha256(hash_input).hexdigest()[:24]
|
||||
batch_id = f"batch_{hash_digest}"
|
||||
|
||||
try:
|
||||
existing_batch = await self.retrieve_batch(batch_id)
|
||||
existing_batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id))
|
||||
|
||||
if (
|
||||
existing_batch.input_file_id != input_file_id
|
||||
or existing_batch.endpoint != endpoint
|
||||
or existing_batch.completion_window != completion_window
|
||||
or existing_batch.metadata != metadata
|
||||
existing_batch.input_file_id != request.input_file_id
|
||||
or existing_batch.endpoint != request.endpoint
|
||||
or existing_batch.completion_window != request.completion_window
|
||||
or existing_batch.metadata != request.metadata
|
||||
):
|
||||
raise ConflictError(
|
||||
f"Idempotency key '{idempotency_key}' was previously used with different parameters. "
|
||||
f"Idempotency key '{request.idempotency_key}' was previously used with different parameters. "
|
||||
"Either use a new idempotency key or ensure all parameters match the original request."
|
||||
)
|
||||
|
||||
|
|
@ -230,12 +232,12 @@ class ReferenceBatchesImpl(Batches):
|
|||
batch = BatchObject(
|
||||
id=batch_id,
|
||||
object="batch",
|
||||
endpoint=endpoint,
|
||||
input_file_id=input_file_id,
|
||||
completion_window=completion_window,
|
||||
endpoint=request.endpoint,
|
||||
input_file_id=request.input_file_id,
|
||||
completion_window=request.completion_window,
|
||||
status="validating",
|
||||
created_at=current_time,
|
||||
metadata=metadata,
|
||||
metadata=request.metadata,
|
||||
)
|
||||
|
||||
await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
|
||||
|
|
@ -247,28 +249,27 @@ class ReferenceBatchesImpl(Batches):
|
|||
|
||||
return batch
|
||||
|
||||
async def cancel_batch(self, batch_id: str) -> BatchObject:
|
||||
async def cancel_batch(self, request: CancelBatchRequest) -> BatchObject:
|
||||
"""Cancel a batch that is in progress."""
|
||||
batch = await self.retrieve_batch(batch_id)
|
||||
batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=request.batch_id))
|
||||
|
||||
if batch.status in ["cancelled", "cancelling"]:
|
||||
return batch
|
||||
|
||||
if batch.status in ["completed", "failed", "expired"]:
|
||||
raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'")
|
||||
raise ConflictError(f"Cannot cancel batch '{request.batch_id}' with status '{batch.status}'")
|
||||
|
||||
await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time()))
|
||||
await self._update_batch(request.batch_id, status="cancelling", cancelling_at=int(time.time()))
|
||||
|
||||
if batch_id in self._processing_tasks:
|
||||
self._processing_tasks[batch_id].cancel()
|
||||
if request.batch_id in self._processing_tasks:
|
||||
self._processing_tasks[request.batch_id].cancel()
|
||||
# note: task removal and status="cancelled" handled in finally block of _process_batch
|
||||
|
||||
return await self.retrieve_batch(batch_id)
|
||||
return await self.retrieve_batch(RetrieveBatchRequest(batch_id=request.batch_id))
|
||||
|
||||
async def list_batches(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int = 20,
|
||||
request: ListBatchesRequest,
|
||||
) -> ListBatchesResponse:
|
||||
"""
|
||||
List all batches, eventually only for the current user.
|
||||
|
|
@ -285,14 +286,14 @@ class ReferenceBatchesImpl(Batches):
|
|||
batches.sort(key=lambda b: b.created_at, reverse=True)
|
||||
|
||||
start_idx = 0
|
||||
if after:
|
||||
if request.after:
|
||||
for i, batch in enumerate(batches):
|
||||
if batch.id == after:
|
||||
if batch.id == request.after:
|
||||
start_idx = i + 1
|
||||
break
|
||||
|
||||
page_batches = batches[start_idx : start_idx + limit]
|
||||
has_more = (start_idx + limit) < len(batches)
|
||||
page_batches = batches[start_idx : start_idx + request.limit]
|
||||
has_more = (start_idx + request.limit) < len(batches)
|
||||
|
||||
first_id = page_batches[0].id if page_batches else None
|
||||
last_id = page_batches[-1].id if page_batches else None
|
||||
|
|
@ -304,11 +305,11 @@ class ReferenceBatchesImpl(Batches):
|
|||
has_more=has_more,
|
||||
)
|
||||
|
||||
async def retrieve_batch(self, batch_id: str) -> BatchObject:
|
||||
async def retrieve_batch(self, request: RetrieveBatchRequest) -> BatchObject:
|
||||
"""Retrieve information about a specific batch."""
|
||||
batch_data = await self.kvstore.get(f"batch:{batch_id}")
|
||||
batch_data = await self.kvstore.get(f"batch:{request.batch_id}")
|
||||
if not batch_data:
|
||||
raise ResourceNotFoundError(batch_id, "Batch", "batches.list()")
|
||||
raise ResourceNotFoundError(request.batch_id, "Batch", "batches.list()")
|
||||
|
||||
return BatchObject.model_validate_json(batch_data)
|
||||
|
||||
|
|
@ -316,7 +317,7 @@ class ReferenceBatchesImpl(Batches):
|
|||
"""Update batch fields in kvstore."""
|
||||
async with self._update_batch_lock:
|
||||
try:
|
||||
batch = await self.retrieve_batch(batch_id)
|
||||
batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id))
|
||||
|
||||
# batch processing is async. once cancelling, only allow "cancelled" status updates
|
||||
if batch.status == "cancelling" and updates.get("status") != "cancelled":
|
||||
|
|
@ -536,7 +537,7 @@ class ReferenceBatchesImpl(Batches):
|
|||
async def _process_batch_impl(self, batch_id: str) -> None:
|
||||
"""Implementation of batch processing logic."""
|
||||
errors: list[BatchError] = []
|
||||
batch = await self.retrieve_batch(batch_id)
|
||||
batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id))
|
||||
|
||||
errors, requests = await self._validate_input(batch)
|
||||
if errors:
|
||||
|
|
|
|||
|
|
@ -26,7 +26,15 @@ from . import common # noqa: F401
|
|||
|
||||
# Import all public API symbols
|
||||
from .agents import Agents, ResponseGuardrail, ResponseGuardrailSpec
|
||||
from .batches import Batches, BatchObject, ListBatchesResponse
|
||||
from .batches import (
|
||||
Batches,
|
||||
BatchObject,
|
||||
CancelBatchRequest,
|
||||
CreateBatchRequest,
|
||||
ListBatchesRequest,
|
||||
ListBatchesResponse,
|
||||
RetrieveBatchRequest,
|
||||
)
|
||||
from .benchmarks import (
|
||||
Benchmark,
|
||||
BenchmarkInput,
|
||||
|
|
@ -462,6 +470,9 @@ __all__ = [
|
|||
"BasicScoringFnParams",
|
||||
"Batches",
|
||||
"BatchObject",
|
||||
"CancelBatchRequest",
|
||||
"CreateBatchRequest",
|
||||
"ListBatchesRequest",
|
||||
"Benchmark",
|
||||
"BenchmarkConfig",
|
||||
"BenchmarkInput",
|
||||
|
|
@ -555,6 +566,7 @@ __all__ = [
|
|||
"LLMAsJudgeScoringFnParams",
|
||||
"LLMRAGQueryGeneratorConfig",
|
||||
"ListBatchesResponse",
|
||||
"RetrieveBatchRequest",
|
||||
"ListBenchmarksResponse",
|
||||
"ListDatasetsResponse",
|
||||
"ListModelsResponse",
|
||||
|
|
|
|||
|
|
@ -1,96 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Literal, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack_api.schema_utils import json_schema_type, webmethod
|
||||
from llama_stack_api.version import LLAMA_STACK_API_V1
|
||||
|
||||
try:
|
||||
from openai.types import Batch as BatchObject
|
||||
except ImportError as e:
|
||||
raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListBatchesResponse(BaseModel):
|
||||
"""Response containing a list of batch objects."""
|
||||
|
||||
object: Literal["list"] = "list"
|
||||
data: list[BatchObject] = Field(..., description="List of batch objects")
|
||||
first_id: str | None = Field(default=None, description="ID of the first batch in the list")
|
||||
last_id: str | None = Field(default=None, description="ID of the last batch in the list")
|
||||
has_more: bool = Field(default=False, description="Whether there are more batches available")
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Batches(Protocol):
|
||||
"""
|
||||
The Batches API enables efficient processing of multiple requests in a single operation,
|
||||
particularly useful for processing large datasets, batch evaluation workflows, and
|
||||
cost-effective inference at scale.
|
||||
|
||||
The API is designed to allow use of openai client libraries for seamless integration.
|
||||
|
||||
This API provides the following extensions:
|
||||
- idempotent batch creation
|
||||
|
||||
Note: This API is currently under active development and may undergo changes.
|
||||
"""
|
||||
|
||||
@webmethod(route="/batches", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def create_batch(
|
||||
self,
|
||||
input_file_id: str,
|
||||
endpoint: str,
|
||||
completion_window: Literal["24h"],
|
||||
metadata: dict[str, str] | None = None,
|
||||
idempotency_key: str | None = None,
|
||||
) -> BatchObject:
|
||||
"""Create a new batch for processing multiple API requests.
|
||||
|
||||
:param input_file_id: The ID of an uploaded file containing requests for the batch.
|
||||
:param endpoint: The endpoint to be used for all requests in the batch.
|
||||
:param completion_window: The time window within which the batch should be processed.
|
||||
:param metadata: Optional metadata for the batch.
|
||||
:param idempotency_key: Optional idempotency key. When provided, enables idempotent behavior.
|
||||
:returns: The created batch object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def retrieve_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Retrieve information about a specific batch.
|
||||
|
||||
:param batch_id: The ID of the batch to retrieve.
|
||||
:returns: The batch object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def cancel_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Cancel a batch that is in progress.
|
||||
|
||||
:param batch_id: The ID of the batch to cancel.
|
||||
:returns: The updated batch object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/batches", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_batches(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> ListBatchesResponse:
|
||||
"""List all batches for the current user.
|
||||
|
||||
:param after: A cursor for pagination; returns batches after this batch ID.
|
||||
:param limit: Number of batches to return (default 20, max 100).
|
||||
:returns: A list of batch objects.
|
||||
"""
|
||||
...
|
||||
40
src/llama_stack_api/batches/__init__.py
Normal file
40
src/llama_stack_api/batches/__init__.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Batches API protocol and models.
|
||||
|
||||
This module contains the Batches protocol definition.
|
||||
Pydantic models are defined in llama_stack_api.batches.models.
|
||||
The FastAPI router is defined in llama_stack_api.batches.fastapi_routes.
|
||||
"""
|
||||
|
||||
from openai.types import Batch as BatchObject
|
||||
|
||||
# Import fastapi_routes for router factory access
|
||||
from . import fastapi_routes
|
||||
|
||||
# Import protocol for re-export
|
||||
from .api import Batches
|
||||
|
||||
# Import models for re-export
|
||||
from .models import (
|
||||
CancelBatchRequest,
|
||||
CreateBatchRequest,
|
||||
ListBatchesRequest,
|
||||
ListBatchesResponse,
|
||||
RetrieveBatchRequest,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Batches",
|
||||
"BatchObject",
|
||||
"CancelBatchRequest",
|
||||
"CreateBatchRequest",
|
||||
"ListBatchesRequest",
|
||||
"ListBatchesResponse",
|
||||
"RetrieveBatchRequest",
|
||||
"fastapi_routes",
|
||||
]
|
||||
53
src/llama_stack_api/batches/api.py
Normal file
53
src/llama_stack_api/batches/api.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from openai.types import Batch as BatchObject
|
||||
|
||||
from .models import (
|
||||
CancelBatchRequest,
|
||||
CreateBatchRequest,
|
||||
ListBatchesRequest,
|
||||
ListBatchesResponse,
|
||||
RetrieveBatchRequest,
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Batches(Protocol):
|
||||
"""
|
||||
The Batches API enables efficient processing of multiple requests in a single operation,
|
||||
particularly useful for processing large datasets, batch evaluation workflows, and
|
||||
cost-effective inference at scale.
|
||||
|
||||
The API is designed to allow use of openai client libraries for seamless integration.
|
||||
|
||||
This API provides the following extensions:
|
||||
- idempotent batch creation
|
||||
|
||||
Note: This API is currently under active development and may undergo changes.
|
||||
"""
|
||||
|
||||
async def create_batch(
|
||||
self,
|
||||
request: CreateBatchRequest,
|
||||
) -> BatchObject: ...
|
||||
|
||||
async def retrieve_batch(
|
||||
self,
|
||||
request: RetrieveBatchRequest,
|
||||
) -> BatchObject: ...
|
||||
|
||||
async def cancel_batch(
|
||||
self,
|
||||
request: CancelBatchRequest,
|
||||
) -> BatchObject: ...
|
||||
|
||||
async def list_batches(
|
||||
self,
|
||||
request: ListBatchesRequest,
|
||||
) -> ListBatchesResponse: ...
|
||||
113
src/llama_stack_api/batches/fastapi_routes.py
Normal file
113
src/llama_stack_api/batches/fastapi_routes.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""FastAPI router for the Batches API.
|
||||
|
||||
This module defines the FastAPI router for the Batches API using standard
|
||||
FastAPI route decorators. The router is defined in the API package to keep
|
||||
all API-related code together.
|
||||
"""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Body, Depends
|
||||
|
||||
from llama_stack_api.batches.models import (
|
||||
CancelBatchRequest,
|
||||
CreateBatchRequest,
|
||||
ListBatchesRequest,
|
||||
RetrieveBatchRequest,
|
||||
)
|
||||
from llama_stack_api.router_utils import create_path_dependency, create_query_dependency, standard_responses
|
||||
from llama_stack_api.version import LLAMA_STACK_API_V1
|
||||
|
||||
from .api import Batches
|
||||
from .models import BatchObject, ListBatchesResponse
|
||||
|
||||
# Automatically generate dependency functions from Pydantic models
|
||||
# This ensures the models are the single source of truth for descriptions
|
||||
get_retrieve_batch_request = create_path_dependency(RetrieveBatchRequest)
|
||||
get_cancel_batch_request = create_path_dependency(CancelBatchRequest)
|
||||
|
||||
|
||||
# Automatically generate dependency function from Pydantic model
|
||||
# This ensures the model is the single source of truth for descriptions and defaults
|
||||
get_list_batches_request = create_query_dependency(ListBatchesRequest)
|
||||
|
||||
|
||||
def create_router(impl: Batches) -> APIRouter:
|
||||
"""Create a FastAPI router for the Batches API.
|
||||
|
||||
Args:
|
||||
impl: The Batches implementation instance
|
||||
|
||||
Returns:
|
||||
APIRouter configured for the Batches API
|
||||
"""
|
||||
router = APIRouter(
|
||||
prefix=f"/{LLAMA_STACK_API_V1}",
|
||||
tags=["Batches"],
|
||||
responses=standard_responses,
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/batches",
|
||||
response_model=BatchObject,
|
||||
summary="Create a new batch for processing multiple API requests.",
|
||||
description="Create a new batch for processing multiple API requests.",
|
||||
responses={
|
||||
200: {"description": "The created batch object."},
|
||||
409: {"description": "Conflict: The idempotency key was previously used with different parameters."},
|
||||
},
|
||||
)
|
||||
async def create_batch(
|
||||
request: Annotated[CreateBatchRequest, Body(...)],
|
||||
) -> BatchObject:
|
||||
return await impl.create_batch(request)
|
||||
|
||||
@router.get(
|
||||
"/batches/{batch_id}",
|
||||
response_model=BatchObject,
|
||||
summary="Retrieve information about a specific batch.",
|
||||
description="Retrieve information about a specific batch.",
|
||||
responses={
|
||||
200: {"description": "The batch object."},
|
||||
},
|
||||
)
|
||||
async def retrieve_batch(
|
||||
request: Annotated[RetrieveBatchRequest, Depends(get_retrieve_batch_request)],
|
||||
) -> BatchObject:
|
||||
return await impl.retrieve_batch(request)
|
||||
|
||||
@router.post(
|
||||
"/batches/{batch_id}/cancel",
|
||||
response_model=BatchObject,
|
||||
summary="Cancel a batch that is in progress.",
|
||||
description="Cancel a batch that is in progress.",
|
||||
responses={
|
||||
200: {"description": "The updated batch object."},
|
||||
},
|
||||
)
|
||||
async def cancel_batch(
|
||||
request: Annotated[CancelBatchRequest, Depends(get_cancel_batch_request)],
|
||||
) -> BatchObject:
|
||||
return await impl.cancel_batch(request)
|
||||
|
||||
@router.get(
|
||||
"/batches",
|
||||
response_model=ListBatchesResponse,
|
||||
summary="List all batches for the current user.",
|
||||
description="List all batches for the current user.",
|
||||
responses={
|
||||
200: {"description": "A list of batch objects."},
|
||||
},
|
||||
)
|
||||
async def list_batches(
|
||||
request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)],
|
||||
) -> ListBatchesResponse:
|
||||
return await impl.list_batches(request)
|
||||
|
||||
return router
|
||||
78
src/llama_stack_api/batches/models.py
Normal file
78
src/llama_stack_api/batches/models.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Pydantic models for Batches API requests and responses.
|
||||
|
||||
This module defines the request and response models for the Batches API
|
||||
using Pydantic with Field descriptions for OpenAPI schema generation.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from openai.types import Batch as BatchObject
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack_api.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CreateBatchRequest(BaseModel):
|
||||
"""Request model for creating a batch."""
|
||||
|
||||
input_file_id: str = Field(..., description="The ID of an uploaded file containing requests for the batch.")
|
||||
endpoint: str = Field(..., description="The endpoint to be used for all requests in the batch.")
|
||||
completion_window: Literal["24h"] = Field(
|
||||
..., description="The time window within which the batch should be processed."
|
||||
)
|
||||
metadata: dict[str, str] | None = Field(default=None, description="Optional metadata for the batch.")
|
||||
idempotency_key: str | None = Field(
|
||||
default=None, description="Optional idempotency key. When provided, enables idempotent behavior."
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListBatchesRequest(BaseModel):
|
||||
"""Request model for listing batches."""
|
||||
|
||||
after: str | None = Field(
|
||||
default=None, description="Optional cursor for pagination. Returns batches after this ID."
|
||||
)
|
||||
limit: int = Field(default=20, description="Maximum number of batches to return. Defaults to 20.")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RetrieveBatchRequest(BaseModel):
|
||||
"""Request model for retrieving a batch."""
|
||||
|
||||
batch_id: str = Field(..., description="The ID of the batch to retrieve.")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CancelBatchRequest(BaseModel):
|
||||
"""Request model for canceling a batch."""
|
||||
|
||||
batch_id: str = Field(..., description="The ID of the batch to cancel.")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListBatchesResponse(BaseModel):
|
||||
"""Response containing a list of batch objects."""
|
||||
|
||||
object: Literal["list"] = "list"
|
||||
data: list[BatchObject] = Field(..., description="List of batch objects")
|
||||
first_id: str | None = Field(default=None, description="ID of the first batch in the list")
|
||||
last_id: str | None = Field(default=None, description="ID of the last batch in the list")
|
||||
has_more: bool = Field(default=False, description="Whether there are more batches available")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CreateBatchRequest",
|
||||
"ListBatchesRequest",
|
||||
"RetrieveBatchRequest",
|
||||
"CancelBatchRequest",
|
||||
"ListBatchesResponse",
|
||||
"BatchObject",
|
||||
]
|
||||
|
|
@ -24,6 +24,7 @@ classifiers = [
|
|||
"Topic :: Scientific/Engineering :: Information Analysis",
|
||||
]
|
||||
dependencies = [
|
||||
"fastapi>=0.115.0,<1.0",
|
||||
"pydantic>=2.11.9",
|
||||
"jsonschema",
|
||||
"opentelemetry-sdk>=1.30.0",
|
||||
|
|
|
|||
155
src/llama_stack_api/router_utils.py
Normal file
155
src/llama_stack_api/router_utils.py
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Utilities for creating FastAPI routers with standard error responses.
|
||||
|
||||
This module provides standard error response definitions for FastAPI routers.
|
||||
These responses use OpenAPI $ref references to component responses defined
|
||||
in the OpenAPI specification.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from collections.abc import Callable
|
||||
from typing import Annotated, Any, TypeVar
|
||||
|
||||
from fastapi import Path, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
standard_responses: dict[int | str, dict[str, Any]] = {
|
||||
400: {"$ref": "#/components/responses/BadRequest400"},
|
||||
429: {"$ref": "#/components/responses/TooManyRequests429"},
|
||||
500: {"$ref": "#/components/responses/InternalServerError500"},
|
||||
"default": {"$ref": "#/components/responses/DefaultError"},
|
||||
}
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
def create_query_dependency[T: BaseModel](model_class: type[T]) -> Callable[..., T]:
|
||||
"""Create a FastAPI dependency function from a Pydantic model for query parameters.
|
||||
|
||||
FastAPI does not natively support using Pydantic models as query parameters
|
||||
without a dependency function. Using a dependency function typically leads to
|
||||
duplication: field types, default values, and descriptions must be repeated in
|
||||
`Query(...)` annotations even though they already exist in the Pydantic model.
|
||||
|
||||
This function automatically generates a dependency function that extracts query parameters
|
||||
from the request and constructs an instance of the Pydantic model. The descriptions and
|
||||
defaults are automatically extracted from the model's Field definitions, making the model
|
||||
the single source of truth.
|
||||
|
||||
Args:
|
||||
model_class: The Pydantic model class to create a dependency for
|
||||
|
||||
Returns:
|
||||
A dependency function that can be used with FastAPI's Depends()
|
||||
```
|
||||
"""
|
||||
# Build function signature dynamically from model fields
|
||||
annotations: dict[str, Any] = {}
|
||||
defaults: dict[str, Any] = {}
|
||||
|
||||
for field_name, field_info in model_class.model_fields.items():
|
||||
# Extract description from Field
|
||||
description = field_info.description
|
||||
|
||||
# Create Query annotation with description from model
|
||||
query_annotation = Query(description=description) if description else Query()
|
||||
|
||||
# Create Annotated type with Query
|
||||
field_type = field_info.annotation
|
||||
annotations[field_name] = Annotated[field_type, query_annotation]
|
||||
|
||||
# Set default value from model
|
||||
if field_info.default is not inspect.Parameter.empty:
|
||||
defaults[field_name] = field_info.default
|
||||
|
||||
# Create the dependency function dynamically
|
||||
def dependency_func(**kwargs: Any) -> T:
|
||||
return model_class(**kwargs)
|
||||
|
||||
# Set function signature
|
||||
sig_params = []
|
||||
for field_name, field_type in annotations.items():
|
||||
default = defaults.get(field_name, inspect.Parameter.empty)
|
||||
param = inspect.Parameter(
|
||||
field_name,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
default=default,
|
||||
annotation=field_type,
|
||||
)
|
||||
sig_params.append(param)
|
||||
|
||||
# These attributes are set dynamically at runtime. While mypy can't verify them statically,
|
||||
# they are standard Python function attributes that exist on all callable objects at runtime.
|
||||
# Setting them allows FastAPI to properly introspect the function signature for dependency injection.
|
||||
dependency_func.__signature__ = inspect.Signature(sig_params) # type: ignore[attr-defined]
|
||||
dependency_func.__annotations__ = annotations # type: ignore[attr-defined]
|
||||
dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request" # type: ignore[attr-defined]
|
||||
|
||||
return dependency_func
|
||||
|
||||
|
||||
def create_path_dependency[T: BaseModel](model_class: type[T]) -> Callable[..., T]:
|
||||
"""Create a FastAPI dependency function from a Pydantic model for path parameters.
|
||||
|
||||
FastAPI requires path parameters to be explicitly annotated with `Path()`. When using
|
||||
a Pydantic model that contains path parameters, you typically need a dependency function
|
||||
that extracts the path parameter and constructs the model. This leads to duplication:
|
||||
the parameter name, type, and description must be repeated in `Path(...)` annotations
|
||||
even though they already exist in the Pydantic model.
|
||||
|
||||
This function automatically generates a dependency function that extracts path parameters
|
||||
from the request and constructs an instance of the Pydantic model. The descriptions are
|
||||
automatically extracted from the model's Field definitions, making the model the single
|
||||
source of truth.
|
||||
|
||||
Args:
|
||||
model_class: The Pydantic model class to create a dependency for. The model should
|
||||
have exactly one field that represents the path parameter.
|
||||
|
||||
Returns:
|
||||
A dependency function that can be used with FastAPI's Depends()
|
||||
```
|
||||
"""
|
||||
# Get the single field from the model (path parameter models typically have one field)
|
||||
if len(model_class.model_fields) != 1:
|
||||
raise ValueError(
|
||||
f"Path parameter model {model_class.__name__} must have exactly one field, "
|
||||
f"but has {len(model_class.model_fields)} fields"
|
||||
)
|
||||
|
||||
field_name, field_info = next(iter(model_class.model_fields.items()))
|
||||
|
||||
# Extract description from Field
|
||||
description = field_info.description
|
||||
|
||||
# Create Path annotation with description from model
|
||||
path_annotation = Path(description=description) if description else Path()
|
||||
|
||||
# Create Annotated type with Path
|
||||
field_type = field_info.annotation
|
||||
annotations: dict[str, Any] = {field_name: Annotated[field_type, path_annotation]}
|
||||
|
||||
# Create the dependency function dynamically
|
||||
def dependency_func(**kwargs: Any) -> T:
|
||||
return model_class(**kwargs)
|
||||
|
||||
# Set function signature
|
||||
param = inspect.Parameter(
|
||||
field_name,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=annotations[field_name],
|
||||
)
|
||||
|
||||
# These attributes are set dynamically at runtime. While mypy can't verify them statically,
|
||||
# they are standard Python function attributes that exist on all callable objects at runtime.
|
||||
# Setting them allows FastAPI to properly introspect the function signature for dependency injection.
|
||||
dependency_func.__signature__ = inspect.Signature([param]) # type: ignore[attr-defined]
|
||||
dependency_func.__annotations__ = annotations # type: ignore[attr-defined]
|
||||
dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request" # type: ignore[attr-defined]
|
||||
|
||||
return dependency_func
|
||||
|
|
@ -58,8 +58,15 @@ import json
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from llama_stack_api import BatchObject, ConflictError, ResourceNotFoundError
|
||||
from llama_stack_api.batches.models import (
|
||||
CancelBatchRequest,
|
||||
CreateBatchRequest,
|
||||
ListBatchesRequest,
|
||||
RetrieveBatchRequest,
|
||||
)
|
||||
|
||||
|
||||
class TestReferenceBatchesImpl:
|
||||
|
|
@ -169,7 +176,7 @@ class TestReferenceBatchesImpl:
|
|||
|
||||
async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data):
|
||||
"""Test successful batch creation and retrieval."""
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
|
||||
|
||||
self._validate_batch_type(created_batch, expected_metadata=sample_batch_data["metadata"])
|
||||
|
||||
|
|
@ -184,7 +191,7 @@ class TestReferenceBatchesImpl:
|
|||
assert isinstance(created_batch.created_at, int)
|
||||
assert created_batch.created_at > 0
|
||||
|
||||
retrieved_batch = await provider.retrieve_batch(created_batch.id)
|
||||
retrieved_batch = await provider.retrieve_batch(RetrieveBatchRequest(batch_id=created_batch.id))
|
||||
|
||||
self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"])
|
||||
|
||||
|
|
@ -197,17 +204,15 @@ class TestReferenceBatchesImpl:
|
|||
async def test_create_batch_without_metadata(self, provider):
|
||||
"""Test batch creation without optional metadata."""
|
||||
batch = await provider.create_batch(
|
||||
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
CreateBatchRequest(input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h")
|
||||
)
|
||||
|
||||
assert batch.metadata is None
|
||||
|
||||
async def test_create_batch_completion_window(self, provider):
|
||||
"""Test batch creation with invalid completion window."""
|
||||
with pytest.raises(ValueError, match="Invalid completion_window"):
|
||||
await provider.create_batch(
|
||||
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now"
|
||||
)
|
||||
with pytest.raises(ValidationError, match="completion_window"):
|
||||
CreateBatchRequest(input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"endpoint",
|
||||
|
|
@ -219,37 +224,43 @@ class TestReferenceBatchesImpl:
|
|||
async def test_create_batch_invalid_endpoints(self, provider, endpoint):
|
||||
"""Test batch creation with various invalid endpoints."""
|
||||
with pytest.raises(ValueError, match="Invalid endpoint"):
|
||||
await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h")
|
||||
await provider.create_batch(
|
||||
CreateBatchRequest(input_file_id="file_123", endpoint=endpoint, completion_window="24h")
|
||||
)
|
||||
|
||||
async def test_create_batch_invalid_metadata(self, provider):
|
||||
"""Test that batch creation fails with invalid metadata."""
|
||||
with pytest.raises(ValueError, match="should be a valid string"):
|
||||
await provider.create_batch(
|
||||
CreateBatchRequest(
|
||||
input_file_id="file_123",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={123: "invalid_key"}, # Non-string key
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="should be a valid string"):
|
||||
await provider.create_batch(
|
||||
CreateBatchRequest(
|
||||
input_file_id="file_123",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={"valid_key": 456}, # Non-string value
|
||||
)
|
||||
)
|
||||
|
||||
async def test_retrieve_batch_not_found(self, provider):
|
||||
"""Test error when retrieving non-existent batch."""
|
||||
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
|
||||
await provider.retrieve_batch("nonexistent_batch")
|
||||
await provider.retrieve_batch(RetrieveBatchRequest(batch_id="nonexistent_batch"))
|
||||
|
||||
async def test_cancel_batch_success(self, provider, sample_batch_data):
|
||||
"""Test successful batch cancellation."""
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
|
||||
assert created_batch.status == "validating"
|
||||
|
||||
cancelled_batch = await provider.cancel_batch(created_batch.id)
|
||||
cancelled_batch = await provider.cancel_batch(CancelBatchRequest(batch_id=created_batch.id))
|
||||
|
||||
assert cancelled_batch.id == created_batch.id
|
||||
assert cancelled_batch.status in ["cancelling", "cancelled"]
|
||||
|
|
@ -260,22 +271,22 @@ class TestReferenceBatchesImpl:
|
|||
async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status):
|
||||
"""Test error when cancelling batch in final states."""
|
||||
provider.process_batches = False
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
|
||||
|
||||
# directly update status in kvstore
|
||||
await provider._update_batch(created_batch.id, status=status)
|
||||
|
||||
with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"):
|
||||
await provider.cancel_batch(created_batch.id)
|
||||
await provider.cancel_batch(CancelBatchRequest(batch_id=created_batch.id))
|
||||
|
||||
async def test_cancel_batch_not_found(self, provider):
|
||||
"""Test error when cancelling non-existent batch."""
|
||||
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
|
||||
await provider.cancel_batch("nonexistent_batch")
|
||||
await provider.cancel_batch(CancelBatchRequest(batch_id="nonexistent_batch"))
|
||||
|
||||
async def test_list_batches_empty(self, provider):
|
||||
"""Test listing batches when none exist."""
|
||||
response = await provider.list_batches()
|
||||
response = await provider.list_batches(ListBatchesRequest())
|
||||
|
||||
assert response.object == "list"
|
||||
assert response.data == []
|
||||
|
|
@ -285,9 +296,9 @@ class TestReferenceBatchesImpl:
|
|||
|
||||
async def test_list_batches_single_batch(self, provider, sample_batch_data):
|
||||
"""Test listing batches with single batch."""
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
|
||||
|
||||
response = await provider.list_batches()
|
||||
response = await provider.list_batches(ListBatchesRequest())
|
||||
|
||||
assert len(response.data) == 1
|
||||
self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"])
|
||||
|
|
@ -300,12 +311,12 @@ class TestReferenceBatchesImpl:
|
|||
"""Test listing multiple batches."""
|
||||
batches = [
|
||||
await provider.create_batch(
|
||||
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
CreateBatchRequest(input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h")
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
response = await provider.list_batches()
|
||||
response = await provider.list_batches(ListBatchesRequest())
|
||||
|
||||
assert len(response.data) == 3
|
||||
|
||||
|
|
@ -321,12 +332,12 @@ class TestReferenceBatchesImpl:
|
|||
"""Test listing batches with limit parameter."""
|
||||
batches = [
|
||||
await provider.create_batch(
|
||||
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
CreateBatchRequest(input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h")
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
response = await provider.list_batches(limit=2)
|
||||
response = await provider.list_batches(ListBatchesRequest(limit=2))
|
||||
|
||||
assert len(response.data) == 2
|
||||
assert response.has_more is True
|
||||
|
|
@ -340,36 +351,36 @@ class TestReferenceBatchesImpl:
|
|||
"""Test listing batches with pagination using 'after' parameter."""
|
||||
for i in range(3):
|
||||
await provider.create_batch(
|
||||
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
CreateBatchRequest(input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h")
|
||||
)
|
||||
|
||||
# Get first page
|
||||
first_page = await provider.list_batches(limit=1)
|
||||
first_page = await provider.list_batches(ListBatchesRequest(limit=1))
|
||||
assert len(first_page.data) == 1
|
||||
assert first_page.has_more is True
|
||||
|
||||
# Get second page using 'after'
|
||||
second_page = await provider.list_batches(limit=1, after=first_page.data[0].id)
|
||||
second_page = await provider.list_batches(ListBatchesRequest(limit=1, after=first_page.data[0].id))
|
||||
assert len(second_page.data) == 1
|
||||
assert second_page.data[0].id != first_page.data[0].id
|
||||
|
||||
# Verify we got the next batch in order
|
||||
all_batches = await provider.list_batches()
|
||||
all_batches = await provider.list_batches(ListBatchesRequest())
|
||||
expected_second_batch_id = all_batches.data[1].id
|
||||
assert second_page.data[0].id == expected_second_batch_id
|
||||
|
||||
async def test_list_batches_invalid_after(self, provider, sample_batch_data):
|
||||
"""Test listing batches with invalid 'after' parameter."""
|
||||
await provider.create_batch(**sample_batch_data)
|
||||
await provider.create_batch(CreateBatchRequest(**sample_batch_data))
|
||||
|
||||
response = await provider.list_batches(after="nonexistent_batch")
|
||||
response = await provider.list_batches(ListBatchesRequest(after="nonexistent_batch"))
|
||||
|
||||
# Should return all batches (no filtering when 'after' batch not found)
|
||||
assert len(response.data) == 1
|
||||
|
||||
async def test_kvstore_persistence(self, provider, sample_batch_data):
|
||||
"""Test that batches are properly persisted in kvstore."""
|
||||
batch = await provider.create_batch(**sample_batch_data)
|
||||
batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
|
||||
|
||||
stored_data = await provider.kvstore.get(f"batch:{batch.id}")
|
||||
assert stored_data is not None
|
||||
|
|
@ -757,7 +768,7 @@ class TestReferenceBatchesImpl:
|
|||
|
||||
for _ in range(3):
|
||||
await provider.create_batch(
|
||||
input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
CreateBatchRequest(input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h")
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.042) # let tasks start
|
||||
|
|
@ -767,8 +778,10 @@ class TestReferenceBatchesImpl:
|
|||
async def test_create_batch_embeddings_endpoint(self, provider):
|
||||
"""Test that batch creation succeeds with embeddings endpoint."""
|
||||
batch = await provider.create_batch(
|
||||
CreateBatchRequest(
|
||||
input_file_id="file_123",
|
||||
endpoint="/v1/embeddings",
|
||||
completion_window="24h",
|
||||
)
|
||||
)
|
||||
assert batch.endpoint == "/v1/embeddings"
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ import asyncio
|
|||
import pytest
|
||||
|
||||
from llama_stack_api import ConflictError
|
||||
from llama_stack_api.batches.models import CreateBatchRequest, RetrieveBatchRequest
|
||||
|
||||
|
||||
class TestReferenceBatchesIdempotency:
|
||||
|
|
@ -56,19 +57,23 @@ class TestReferenceBatchesIdempotency:
|
|||
del sample_batch_data["metadata"]
|
||||
|
||||
batch1 = await provider.create_batch(
|
||||
CreateBatchRequest(
|
||||
**sample_batch_data,
|
||||
metadata={"test": "value1", "other": "value2"},
|
||||
idempotency_key="unique-token-1",
|
||||
)
|
||||
)
|
||||
|
||||
# sleep for 1 second to allow created_at timestamps to be different
|
||||
await asyncio.sleep(1)
|
||||
|
||||
batch2 = await provider.create_batch(
|
||||
CreateBatchRequest(
|
||||
**sample_batch_data,
|
||||
metadata={"other": "value2", "test": "value1"}, # Different order
|
||||
idempotency_key="unique-token-1",
|
||||
)
|
||||
)
|
||||
|
||||
assert batch1.id == batch2.id
|
||||
assert batch1.input_file_id == batch2.input_file_id
|
||||
|
|
@ -77,23 +82,17 @@ class TestReferenceBatchesIdempotency:
|
|||
|
||||
async def test_different_idempotency_keys_create_different_batches(self, provider, sample_batch_data):
|
||||
"""Test that different idempotency keys create different batches even with same params."""
|
||||
batch1 = await provider.create_batch(
|
||||
**sample_batch_data,
|
||||
idempotency_key="token-A",
|
||||
)
|
||||
batch1 = await provider.create_batch(CreateBatchRequest(**sample_batch_data, idempotency_key="token-A"))
|
||||
|
||||
batch2 = await provider.create_batch(
|
||||
**sample_batch_data,
|
||||
idempotency_key="token-B",
|
||||
)
|
||||
batch2 = await provider.create_batch(CreateBatchRequest(**sample_batch_data, idempotency_key="token-B"))
|
||||
|
||||
assert batch1.id != batch2.id
|
||||
|
||||
async def test_non_idempotent_behavior_without_key(self, provider, sample_batch_data):
|
||||
"""Test that batches without idempotency key create unique batches even with identical parameters."""
|
||||
batch1 = await provider.create_batch(**sample_batch_data)
|
||||
batch1 = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
|
||||
|
||||
batch2 = await provider.create_batch(**sample_batch_data)
|
||||
batch2 = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
|
||||
|
||||
assert batch1.id != batch2.id
|
||||
assert batch1.input_file_id == batch2.input_file_id
|
||||
|
|
@ -117,12 +116,12 @@ class TestReferenceBatchesIdempotency:
|
|||
|
||||
sample_batch_data[param_name] = first_value
|
||||
|
||||
batch1 = await provider.create_batch(**sample_batch_data)
|
||||
batch1 = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
|
||||
|
||||
with pytest.raises(ConflictError, match="Idempotency key.*was previously used with different parameters"):
|
||||
sample_batch_data[param_name] = second_value
|
||||
await provider.create_batch(**sample_batch_data)
|
||||
await provider.create_batch(CreateBatchRequest(**sample_batch_data))
|
||||
|
||||
retrieved_batch = await provider.retrieve_batch(batch1.id)
|
||||
retrieved_batch = await provider.retrieve_batch(RetrieveBatchRequest(batch_id=batch1.id))
|
||||
assert retrieved_batch.id == batch1.id
|
||||
assert getattr(retrieved_batch, param_name) == first_value
|
||||
|
|
|
|||
4
uv.lock
generated
4
uv.lock
generated
|
|
@ -1,5 +1,5 @@
|
|||
version = 1
|
||||
revision = 2
|
||||
revision = 3
|
||||
requires-python = ">=3.12"
|
||||
resolution-markers = [
|
||||
"(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')",
|
||||
|
|
@ -2292,6 +2292,7 @@ name = "llama-stack-api"
|
|||
version = "0.4.0.dev0"
|
||||
source = { editable = "src/llama_stack_api" }
|
||||
dependencies = [
|
||||
{ name = "fastapi" },
|
||||
{ name = "jsonschema" },
|
||||
{ name = "opentelemetry-exporter-otlp-proto-http" },
|
||||
{ name = "opentelemetry-sdk" },
|
||||
|
|
@ -2300,6 +2301,7 @@ dependencies = [
|
|||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "fastapi", specifier = ">=0.115.0,<1.0" },
|
||||
{ name = "jsonschema" },
|
||||
{ name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" },
|
||||
{ name = "opentelemetry-sdk", specifier = ">=1.30.0" },
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue