diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py
index cebfabba5..6c51dc2c7 100644
--- a/llama_stack/distribution/library_client.py
+++ b/llama_stack/distribution/library_client.py
@@ -12,11 +12,13 @@ import os
 import sys
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
+from io import BytesIO
 from pathlib import Path
 from typing import Any, TypeVar, Union, get_args, get_origin
 
 import httpx
 import yaml
+from fastapi import Response as FastAPIResponse
 from llama_stack_client import (
     NOT_GIVEN,
     APIResponse,
@@ -112,6 +114,27 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
         raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e
 
 
+class LibraryClientUploadFile:
+    """LibraryClient UploadFile object that mimics FastAPI's UploadFile interface."""
+
+    def __init__(self, filename: str, content: bytes):
+        self.filename = filename
+        self.content = content
+        self.content_type = "application/octet-stream"
+
+    async def read(self) -> bytes:
+        return self.content
+
+
+class LibraryClientHttpxResponse:
+    """LibraryClient httpx Response object for FastAPI Response conversion."""
+
+    def __init__(self, response):
+        self.content = response.body if isinstance(response.body, bytes) else response.body.encode()
+        self.status_code = response.status_code
+        self.headers = response.headers
+
+
 class LlamaStackAsLibraryClient(LlamaStackClient):
     def __init__(
         self,
@@ -295,6 +318,31 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             )
         return response
 
+    def _handle_file_uploads(self, options: Any, body: dict) -> tuple[dict, list[str]]:
+        """Handle file uploads from OpenAI client and add them to the request body."""
+        if not (hasattr(options, "files") and options.files):
+            return body, []
+
+        if not isinstance(options.files, list):
+            return body, []
+
+        field_names = []
+        for file_tuple in options.files:
+            if not (isinstance(file_tuple, tuple) and len(file_tuple) >= 2):
+                continue
+
+            field_name = file_tuple[0]
+            file_object = file_tuple[1]
+
+            if isinstance(file_object, BytesIO):
+                file_object.seek(0)
+                file_content = file_object.read()
+                filename = getattr(file_object, "name", "uploaded_file")
+                field_names.append(field_name)
+                body[field_name] = LibraryClientUploadFile(filename, file_content)
+
+        return body, field_names
+
     async def _call_non_streaming(
         self,
         *,
@@ -310,15 +358,23 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
 
         matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls)
         body |= path_params
-        body = self._convert_body(path, options.method, body)
+
+        body, field_names = self._handle_file_uploads(options, body)
+
+        body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
         await start_trace(route, {"__location__": "library_client"})
         try:
             result = await matched_func(**body)
         finally:
             await end_trace()
 
+        # Handle FastAPI Response objects (e.g., from file content retrieval)
+        if isinstance(result, FastAPIResponse):
+            return LibraryClientHttpxResponse(result)
+
         json_content = json.dumps(convert_pydantic_to_json_value(result))
 
+        filtered_body = {k: v for k, v in body.items() if not isinstance(v, LibraryClientUploadFile)}
         mock_response = httpx.Response(
             status_code=httpx.codes.OK,
             content=json_content.encode("utf-8"),
@@ -330,7 +386,7 @@
                 url=options.url,
                 params=options.params,
                 headers=options.headers or {},
-                json=convert_pydantic_to_json_value(body),
+                json=convert_pydantic_to_json_value(filtered_body),
             ),
         )
         response = APIResponse(
@@ -404,13 +460,17 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         )
         return await response.parse()
 
-    def _convert_body(self, path: str, method: str, body: dict | None = None) -> dict:
+    def _convert_body(
+        self, path: str, method: str, body: dict | None = None, exclude_params: set[str] | None = None
+    ) -> dict:
         if not body:
             return {}
 
         if self.route_impls is None:
             raise ValueError("Client not initialized")
 
+        exclude_params = exclude_params or set()
+
         func, _, _ = find_matching_route(method, path, self.route_impls)
         sig = inspect.signature(func)
 
@@ -422,6 +482,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         for param_name, param in sig.parameters.items():
             if param_name in body:
                 value = body.get(param_name)
-                converted_body[param_name] = convert_to_pydantic(param.annotation, value)
+                if param_name in exclude_params:
+                    converted_body[param_name] = value
+                else:
+                    converted_body[param_name] = convert_to_pydantic(param.annotation, value)
 
         return converted_body
diff --git a/tests/integration/files/test_files.py b/tests/integration/files/test_files.py
index 8375507dc..8547ef2f3 100644
--- a/tests/integration/files/test_files.py
+++ b/tests/integration/files/test_files.py
@@ -7,15 +7,16 @@
 from io import BytesIO
 
 import pytest
+from openai import OpenAI
 
 from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
 
 
-def test_openai_client_basic_operations(openai_client, client_with_models):
+def test_openai_client_basic_operations(compat_client, client_with_models):
     """Test basic file operations through OpenAI client."""
-    if isinstance(client_with_models, LlamaStackAsLibraryClient):
-        pytest.skip("OpenAI files are not supported when testing with library client yet.")
-    client = openai_client
+    if isinstance(client_with_models, LlamaStackAsLibraryClient) and isinstance(compat_client, OpenAI):
+        pytest.skip("OpenAI files are not supported when testing with LlamaStackAsLibraryClient")
+    client = compat_client
 
     test_content = b"files test content"
 
@@ -41,7 +42,12 @@ def test_openai_client_basic_operations(openai_client, client_with_models):
     # Retrieve file content - OpenAI client returns httpx Response object
     content_response = client.files.content(uploaded_file.id)
     # The response is an httpx Response object with .content attribute containing bytes
-    content = content_response.content
+    if isinstance(content_response, str):
+        # Llama Stack Client returns a str
+        # TODO: fix Llama Stack Client
+        content = bytes(content_response, "utf-8")
+    else:
+        content = content_response.content
     assert content == test_content
 
     # Delete file
diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py
index 749793b64..f6b5b3026 100644
--- a/tests/integration/fixtures/common.py
+++ b/tests/integration/fixtures/common.py
@@ -257,6 +257,11 @@ def openai_client(client_with_models):
     return OpenAI(base_url=base_url, api_key="fake")
 
 
+@pytest.fixture(params=["openai_client", "llama_stack_client"])
+def compat_client(request):
+    return request.getfixturevalue(request.param)
+
+
 @pytest.fixture(scope="session", autouse=True)
 def cleanup_server_process(request):
     """Cleanup server process at the end of the test session."""
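
Note: a minimal sketch (not part of the diff) of how the new upload path fits together end to end. The `roundtrip_file` helper and the filename `example.txt` are hypothetical; `client.files.create` and `client.files.content` are the standard OpenAI-compatible calls already exercised by the test above, and the `name` attribute on the `BytesIO` is what `_handle_file_uploads` picks up via `getattr(file_object, "name", "uploaded_file")`.

```python
from io import BytesIO


def roundtrip_file(client) -> bytes:
    """Upload and re-download a small file through an OpenAI-compatible client.

    `client` may be an OpenAI client or a LlamaStackAsLibraryClient; with the
    library client, the upload is converted in-process by _handle_file_uploads
    into a LibraryClientUploadFile before reaching the route implementation.
    """
    file_obj = BytesIO(b"files test content")
    # _handle_file_uploads falls back to "uploaded_file" when the buffer has
    # no name, so set one explicitly to record a realistic filename.
    file_obj.name = "example.txt"

    uploaded = client.files.create(file=file_obj, purpose="assistants")

    # File content retrieval yields a FastAPI Response in-process, which the
    # library client now wraps as LibraryClientHttpxResponse, so .content
    # carries the raw bytes; the remote Llama Stack client may return a str.
    response = client.files.content(uploaded.id)
    return response.encode() if isinstance(response, str) else response.content
```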