From 7a1b60fccf656ab6606eaf83c1f6bde3efbe0a04 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Mon, 16 Dec 2024 22:30:10 -0800
Subject: [PATCH] Minor

---
 llama_stack/distribution/library_client.py   |  7 ++++---
 tests/client-sdk/inference/test_inference.py | 13 -------------
 2 files changed, 4 insertions(+), 16 deletions(-)

diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py
index 50b867366..14f62e3a6 100644
--- a/llama_stack/distribution/library_client.py
+++ b/llama_stack/distribution/library_client.py
@@ -25,7 +25,6 @@ from llama_stack_client import (
     AsyncStream,
     LlamaStackClient,
     NOT_GIVEN,
-    Stream,
 )
 from pydantic import BaseModel, TypeAdapter
 from rich.console import Console
@@ -370,8 +369,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 json=options.json_data,
             ),
         )
-        origin = get_origin(stream_cls)
-        assert origin is Stream
+
+        # we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
+        # however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
+        # so we need to convert it to AsyncStream
         args = get_args(stream_cls)
         stream_cls = AsyncStream[args[0]]
         response = AsyncAPIResponse(
diff --git a/tests/client-sdk/inference/test_inference.py b/tests/client-sdk/inference/test_inference.py
index d00ae12a8..ea9cfb8ae 100644
--- a/tests/client-sdk/inference/test_inference.py
+++ b/tests/client-sdk/inference/test_inference.py
@@ -4,23 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import sys
-import traceback
-import warnings
-
 import pytest
 from llama_stack_client.lib.inference.event_logger import EventLogger
 
 
-def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
-    log = file if hasattr(file, "write") else sys.stderr
-    traceback.print_stack(file=log)
-    log.write(warnings.formatwarning(message, category, filename, lineno, line))
-
-
-warnings.showwarning = warn_with_traceback
-
-
 def test_text_chat_completion(llama_stack_client):
     # non-streaming
     available_models = [
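
For readers unfamiliar with the typing machinery used in the library_client.py hunk, here is a minimal, self-contained sketch of the mechanic it relies on: pulling the type argument out of a parametrized stream class with typing.get_args and re-parametrizing the async wrapper with it. The Stream, AsyncStream, and ChatCompletionChunk classes below are toy stand-ins for illustration, not the real llama_stack_client types, and the sketch is not part of the patch.

# Toy stand-ins (assumptions for illustration only); the real classes live in llama_stack_client.
from typing import Generic, TypeVar, get_args

T = TypeVar("T")


class Stream(Generic[T]):
    """Stand-in for a synchronous stream wrapper."""


class AsyncStream(Generic[T]):
    """Stand-in for the asynchronous stream wrapper."""


class ChatCompletionChunk:
    """Stand-in for the streamed payload type."""


# A synchronous top-level caller hands over a parametrized sync stream class...
stream_cls = Stream[ChatCompletionChunk]

# ...from which we extract the payload type and rebuild the async equivalent,
# mirroring `args = get_args(stream_cls); stream_cls = AsyncStream[args[0]]` in the patch.
args = get_args(stream_cls)
async_stream_cls = AsyncStream[args[0]]

print(async_stream_cls)  # e.g. __main__.AsyncStream[__main__.ChatCompletionChunk]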