mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 08:44:44 +00:00
feat(llama_stack/apis/common/responses.py, llama_stack/distribution/server/server.py, tests/integration/datasets/test_datasets.py): enhance pagination response with next URL
Add an optional URL field to the PaginatedResponse model to facilitate pagination in API responses. Update the server logic to populate this URL based on the current request parameters when more data is available. Enhance tests to verify the presence and correctness of the URL in pagination scenarios.
This commit is contained in:
parent
4597145011
commit
0f765f00c2
3 changed files with 106 additions and 11 deletions
|
@@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@@ -17,7 +17,9 @@ class PaginatedResponse(BaseModel):
|
|||
|
||||
:param data: The list of items for the current page
|
||||
:param has_more: Whether there are more items available after this set
|
||||
:param url: Optional URL to fetch the next page of results. Only present if has_more is true.
|
||||
"""
|
||||
|
||||
data: list[dict[str, Any]]
|
||||
data: List[Dict[str, Any]]
|
||||
has_more: bool
|
||||
url: Optional[str] = None
|
||||
|
|
|
@@ -1,3 +1,4 @@
|
|||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
|
@@ -12,6 +13,7 @@ import os
|
|||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
import urllib.parse
|
||||
from contextlib import asynccontextmanager
|
||||
from importlib.metadata import version as parse_version
|
||||
from pathlib import Path
|
||||
|
@@ -25,6 +27,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
|
|||
from openai import BadRequestError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.request_headers import (
|
||||
|
@@ -202,8 +205,46 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
|
|||
)
|
||||
return StreamingResponse(gen, media_type="text/event-stream")
|
||||
else:
|
||||
value = func(**kwargs)
|
||||
return await maybe_await(value)
|
||||
# Execute the actual implementation function
|
||||
result_value = func(**kwargs)
|
||||
value = await maybe_await(result_value)
|
||||
|
||||
# Check if the result is a PaginatedResponse and needs a next URL
|
||||
if isinstance(value, PaginatedResponse) and value.has_more:
|
||||
try:
|
||||
# Retrieve pagination params from original call kwargs
|
||||
limit = kwargs.get("limit")
|
||||
start_index = kwargs.get("start_index", 0) # Default to 0 if not provided
|
||||
|
||||
# Ensure params are integers
|
||||
limit = int(limit) if limit is not None else None
|
||||
start_index = int(start_index) if start_index is not None else 0
|
||||
|
||||
if limit is not None and limit > 0:
|
||||
next_start_index = start_index + limit
|
||||
|
||||
# Build query params for the next page URL
|
||||
next_params = dict(request.query_params)
|
||||
next_params['start_index'] = str(next_start_index)
|
||||
# Ensure limit is also included/updated if necessary
|
||||
next_params['limit'] = str(limit)
|
||||
|
||||
# Construct the full URL for the next page
|
||||
next_url = str(request.url.replace(query=urllib.parse.urlencode(next_params)))
|
||||
# Assign the URL to the response object (assuming 'url' field exists)
|
||||
value.url = next_url
|
||||
else:
|
||||
# Log a warning if limit is missing or invalid for pagination that has_more
|
||||
logger.warning(f"PaginatedResponse has_more=True but limit is missing or invalid for request: {request.url}")
|
||||
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.error(f"Error processing pagination parameters for URL generation: {e}", exc_info=True)
|
||||
except AttributeError:
|
||||
# This might happen if PaginatedResponse doesn't have the 'url' field yet.
|
||||
# Should not happen if Task 1 was completed correctly.
|
||||
logger.error("PaginatedResponse object does not have a 'url' attribute. Ensure the model definition is updated.", exc_info=True)
|
||||
|
||||
return value # Return the (potentially modified) value
|
||||
except Exception as e:
|
||||
logger.exception(f"Error executing endpoint {route=} {method=}")
|
||||
raise translate_exception(e) from e
|
||||
|
|
|
@@ -1,3 +1,4 @@
|
|||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
|
@@ -8,6 +9,7 @@
|
|||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
|
||||
import pytest
|
||||
|
||||
|
@@ -33,7 +35,7 @@ def data_url_from_file(file_path: str) -> str:
|
|||
|
||||
@pytest.mark.skip(reason="flaky. Couldn't find 'llamastack/simpleqa' on the Hugging Face Hub")
|
||||
@pytest.mark.parametrize(
|
||||
"purpose, source, provider_id, limit",
|
||||
"purpose, source, provider_id, limit, total_expected",
|
||||
[
|
||||
(
|
||||
"eval/messages-answer",
|
||||
|
@@ -42,7 +44,8 @@ def data_url_from_file(file_path: str) -> str:
|
|||
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
|
||||
},
|
||||
"huggingface",
|
||||
10,
|
||||
5, # Request 5, expect more
|
||||
10, # Assume total > 5
|
||||
),
|
||||
(
|
||||
"eval/messages-answer",
|
||||
|
@@ -62,10 +65,20 @@ def data_url_from_file(file_path: str) -> str:
|
|||
],
|
||||
"answer": "Paris",
|
||||
},
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Third message",
|
||||
}
|
||||
],
|
||||
"answer": "Third answer",
|
||||
},
|
||||
],
|
||||
},
|
||||
"localfs",
|
||||
2,
|
||||
2, # Request 2, expect more
|
||||
3, # Total is 3
|
||||
),
|
||||
(
|
||||
"eval/messages-answer",
|
||||
|
@@ -74,23 +87,62 @@ def data_url_from_file(file_path: str) -> str:
|
|||
"uri": data_url_from_file(os.path.join(os.path.dirname(__file__), "test_dataset.csv")),
|
||||
},
|
||||
"localfs",
|
||||
5,
|
||||
3, # Request 3, expect more
|
||||
5, # Total is 5
|
||||
),
|
||||
(
|
||||
"eval/messages-answer",
|
||||
{
|
||||
"type": "uri",
|
||||
"uri": data_url_from_file(os.path.join(os.path.dirname(__file__), "test_dataset.csv")),
|
||||
},
|
||||
"localfs",
|
||||
5, # Request all 5, expect no more
|
||||
5, # Total is 5
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_register_and_iterrows(llama_stack_client, purpose, source, provider_id, limit):
|
||||
def test_register_and_iterrows(llama_stack_client, purpose, source, provider_id, limit, total_expected):
|
||||
dataset = llama_stack_client.datasets.register(
|
||||
purpose=purpose,
|
||||
source=source,
|
||||
)
|
||||
assert dataset.identifier is not None
|
||||
assert dataset.provider_id == provider_id
|
||||
iterrow_response = llama_stack_client.datasets.iterrows(dataset.identifier, limit=limit)
|
||||
assert len(iterrow_response.data) == limit
|
||||
|
||||
# Initial request
|
||||
start_index = 0
|
||||
iterrow_response = llama_stack_client.datasets.iterrows(dataset.identifier, limit=limit, start_index=start_index)
|
||||
assert len(iterrow_response.data) == min(limit, total_expected)
|
||||
|
||||
# Check pagination fields
|
||||
expected_has_more = (start_index + limit) < total_expected
|
||||
assert iterrow_response.has_more == expected_has_more
|
||||
|
||||
if expected_has_more:
|
||||
assert hasattr(iterrow_response, "url"), "PaginatedResponse should have a 'url' field when has_more is True"
|
||||
assert iterrow_response.url is not None, "PaginatedResponse url should not be None when has_more is True"
|
||||
# Parse the URL to check parameters
|
||||
parsed_url = urlparse(iterrow_response.url)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
assert "start_index" in query_params, "Next page URL must contain start_index"
|
||||
assert int(query_params["start_index"][0]) == start_index + limit, "Next page URL start_index is incorrect"
|
||||
assert "limit" in query_params, "Next page URL must contain limit"
|
||||
assert int(query_params["limit"][0]) == limit, "Next page URL limit is incorrect"
|
||||
assert parsed_url.path == f"/datasets/{dataset.identifier}/iterrows", "Next page URL path is incorrect"
|
||||
|
||||
# Optionally, make a request to the next page URL (requires client base_url to be set)
|
||||
# This is more complex as it bypasses the client method
|
||||
|
||||
else:
|
||||
assert not hasattr(iterrow_response, "url") or iterrow_response.url is None, "PaginatedResponse url should be None or missing when has_more is False"
|
||||
|
||||
|
||||
# List and check presence
|
||||
dataset_list = llama_stack_client.datasets.list()
|
||||
assert dataset.identifier in [d.identifier for d in dataset_list]
|
||||
|
||||
# Unregister and check absence
|
||||
llama_stack_client.datasets.unregister(dataset.identifier)
|
||||
dataset_list = llama_stack_client.datasets.list()
|
||||
assert dataset.identifier not in [d.identifier for d in dataset_list]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue