feat: enable ls client for files tests (#2769)

# What does this PR do?
As titled. Adds file-upload and FastAPI `Response` handling to the library client so the files tests can run against the Llama Stack client as well as the OpenAI client, via a new `compat_client` fixture.

## Test Plan
CI
ehhuang 2025-07-18 12:10:30 -07:00 committed by GitHub
parent 874b1cb00f
commit 6d55f2f137
3 changed files with 83 additions and 9 deletions

File 1 of 3:

@@ -12,11 +12,13 @@ import os
 import sys
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
+from io import BytesIO
 from pathlib import Path
 from typing import Any, TypeVar, Union, get_args, get_origin

 import httpx
 import yaml
+from fastapi import Response as FastAPIResponse
 from llama_stack_client import (
     NOT_GIVEN,
     APIResponse,
@@ -112,6 +114,27 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
         raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e


+class LibraryClientUploadFile:
+    """LibraryClient UploadFile object that mimics FastAPI's UploadFile interface."""
+
+    def __init__(self, filename: str, content: bytes):
+        self.filename = filename
+        self.content = content
+        self.content_type = "application/octet-stream"
+
+    async def read(self) -> bytes:
+        return self.content
+
+
+class LibraryClientHttpxResponse:
+    """LibraryClient httpx Response object for FastAPI Response conversion."""
+
+    def __init__(self, response):
+        self.content = response.body if isinstance(response.body, bytes) else response.body.encode()
+        self.status_code = response.status_code
+        self.headers = response.headers
+
+
 class LlamaStackAsLibraryClient(LlamaStackClient):
     def __init__(
         self,
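The two classes added above are in-process stand-ins: `LibraryClientUploadFile` mimics the slice of FastAPI's `UploadFile` interface that providers need, and `LibraryClientHttpxResponse` exposes the `.content`/`.status_code`/`.headers` surface an OpenAI-style client expects from httpx. A minimal sketch of how they behave in isolation (the payload and filename are made up for illustration):

```python
import asyncio

from fastapi import Response as FastAPIResponse

# UploadFile-like object: providers can `await upload.read()` as for a real upload.
upload = LibraryClientUploadFile("test.txt", b"files test content")
assert asyncio.run(upload.read()) == b"files test content"
assert upload.content_type == "application/octet-stream"

# Wrapping a FastAPI Response yields the httpx-like attributes downstream code reads.
wrapped = LibraryClientHttpxResponse(FastAPIResponse(content=b"files test content"))
assert wrapped.content == b"files test content"
assert wrapped.status_code == 200
```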
@@ -295,6 +318,31 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         )
         return response

+    def _handle_file_uploads(self, options: Any, body: dict) -> tuple[dict, list[str]]:
+        """Handle file uploads from OpenAI client and add them to the request body."""
+        if not (hasattr(options, "files") and options.files):
+            return body, []
+
+        if not isinstance(options.files, list):
+            return body, []
+
+        field_names = []
+        for file_tuple in options.files:
+            if not (isinstance(file_tuple, tuple) and len(file_tuple) >= 2):
+                continue
+
+            field_name = file_tuple[0]
+            file_object = file_tuple[1]
+            if isinstance(file_object, BytesIO):
+                file_object.seek(0)
+                file_content = file_object.read()
+                filename = getattr(file_object, "name", "uploaded_file")
+                field_names.append(field_name)
+                body[field_name] = LibraryClientUploadFile(filename, file_content)
+
+        return body, field_names
+
     async def _call_non_streaming(
         self,
         *,
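`_handle_file_uploads` only recognizes the `(field_name, BytesIO)` tuple form that the OpenAI-compatible SDK produces for multipart uploads; anything else falls through untouched. A hedged sketch of that shape with hypothetical values (`options` is a stand-in for the SDK's request options object, and `client` is assumed to be an initialized `AsyncLlamaStackAsLibraryClient`):

```python
from io import BytesIO
from types import SimpleNamespace

buf = BytesIO(b"files test content")
buf.name = "test.txt"  # picked up via getattr(...); otherwise "uploaded_file" is used

options = SimpleNamespace(files=[("file", buf)])  # stand-in for the real request options
body = {"purpose": "assistants"}  # hypothetical non-file field

body, field_names = client._handle_file_uploads(options, body)
# body["file"] is now a LibraryClientUploadFile and field_names == ["file"], so the
# subsequent _convert_body call skips pydantic conversion for that field.
```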
@@ -310,15 +358,23 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls)
         body |= path_params
-        body = self._convert_body(path, options.method, body)
+
+        body, field_names = self._handle_file_uploads(options, body)
+
+        body = self._convert_body(path, options.method, body, exclude_params=set(field_names))

         await start_trace(route, {"__location__": "library_client"})
         try:
             result = await matched_func(**body)
         finally:
             await end_trace()

+        # Handle FastAPI Response objects (e.g., from file content retrieval)
+        if isinstance(result, FastAPIResponse):
+            return LibraryClientHttpxResponse(result)
+
         json_content = json.dumps(convert_pydantic_to_json_value(result))

+        filtered_body = {k: v for k, v in body.items() if not isinstance(v, LibraryClientUploadFile)}
         mock_response = httpx.Response(
             status_code=httpx.codes.OK,
             content=json_content.encode("utf-8"),
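The `filtered_body` comprehension is presumably there because `LibraryClientUploadFile` instances are neither pydantic models nor JSON-serializable, and the body is re-encoded into the mock `httpx.Request` below. A small sketch of the failure it avoids, using made-up values:

```python
import json

body = {"purpose": "assistants", "file": LibraryClientUploadFile("test.txt", b"data")}
filtered_body = {k: v for k, v in body.items() if not isinstance(v, LibraryClientUploadFile)}

json.dumps(filtered_body)  # fine: only plain JSON values remain
# json.dumps(body) would raise TypeError: LibraryClientUploadFile is not JSON serializable
```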
@@ -330,7 +386,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 url=options.url,
                 params=options.params,
                 headers=options.headers or {},
-                json=convert_pydantic_to_json_value(body),
+                json=convert_pydantic_to_json_value(filtered_body),
             ),
         )
         response = APIResponse(
@@ -404,13 +460,17 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         )
         return await response.parse()

-    def _convert_body(self, path: str, method: str, body: dict | None = None) -> dict:
+    def _convert_body(
+        self, path: str, method: str, body: dict | None = None, exclude_params: set[str] | None = None
+    ) -> dict:
         if not body:
             return {}

         if self.route_impls is None:
             raise ValueError("Client not initialized")

+        exclude_params = exclude_params or set()
+
         func, _, _ = find_matching_route(method, path, self.route_impls)
         sig = inspect.signature(func)
@@ -422,6 +482,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         for param_name, param in sig.parameters.items():
             if param_name in body:
                 value = body.get(param_name)
-                converted_body[param_name] = convert_to_pydantic(param.annotation, value)
+                if param_name in exclude_params:
+                    converted_body[param_name] = value
+                else:
+                    converted_body[param_name] = convert_to_pydantic(param.annotation, value)

         return converted_body

File 2 of 3:

@@ -7,15 +7,16 @@
 from io import BytesIO

 import pytest
+from openai import OpenAI

 from llama_stack.distribution.library_client import LlamaStackAsLibraryClient


-def test_openai_client_basic_operations(openai_client, client_with_models):
+def test_openai_client_basic_operations(compat_client, client_with_models):
     """Test basic file operations through OpenAI client."""
-    if isinstance(client_with_models, LlamaStackAsLibraryClient):
-        pytest.skip("OpenAI files are not supported when testing with library client yet.")
-    client = openai_client
+    if isinstance(client_with_models, LlamaStackAsLibraryClient) and isinstance(compat_client, OpenAI):
+        pytest.skip("OpenAI files are not supported when testing with LlamaStackAsLibraryClient")
+    client = compat_client

     test_content = b"files test content"
@@ -41,7 +42,12 @@ def test_openai_client_basic_operations(openai_client, client_with_models):
     # Retrieve file content - OpenAI client returns httpx Response object
     content_response = client.files.content(uploaded_file.id)
     # The response is an httpx Response object with .content attribute containing bytes
-    content = content_response.content
+    if isinstance(content_response, str):
+        # Llama Stack Client returns a str
+        # TODO: fix Llama Stack Client
+        content = bytes(content_response, "utf-8")
+    else:
+        content = content_response.content
     assert content == test_content

     # Delete file
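The branch above normalizes the two clients: the OpenAI client returns an httpx `Response` whose `.content` is bytes, while the Llama Stack client currently returns a `str` (hence the TODO). The same normalization as a tiny standalone helper (the function name is mine, not part of the PR):

```python
def normalize_file_content(content_response) -> bytes:
    # Llama Stack Client currently returns str; the OpenAI client returns an httpx.Response.
    if isinstance(content_response, str):
        return bytes(content_response, "utf-8")
    return content_response.content


assert normalize_file_content("files test content") == b"files test content"
```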

File 3 of 3:

@@ -257,6 +257,11 @@ def openai_client(client_with_models):
     return OpenAI(base_url=base_url, api_key="fake")


+@pytest.fixture(params=["openai_client", "llama_stack_client"])
+def compat_client(request):
+    return request.getfixturevalue(request.param)
+
+
 @pytest.fixture(scope="session", autouse=True)
 def cleanup_server_process(request):
     """Cleanup server process at the end of the test session."""