diff --git a/llama_stack/providers/remote/inference/tgi/__init__.py b/llama_stack/providers/remote/inference/tgi/__init__.py
index 834e51324..436749010 100644
--- a/llama_stack/providers/remote/inference/tgi/__init__.py
+++ b/llama_stack/providers/remote/inference/tgi/__init__.py
@@ -4,23 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Union
-
-from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
+from .config import TGIImplConfig
 
 
-async def get_adapter_impl(
-    config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig],
-    _deps,
-):
-    from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter
+async def get_adapter_impl(config: TGIImplConfig, _deps):
+    from .tgi import TGIAdapter
 
     if isinstance(config, TGIImplConfig):
         impl = TGIAdapter()
-    elif isinstance(config, InferenceAPIImplConfig):
-        impl = InferenceAPIAdapter()
-    elif isinstance(config, InferenceEndpointImplConfig):
-        impl = InferenceEndpointAdapter()
     else:
         raise ValueError(
             f"Invalid configuration. Expected 'TGIAdapter', 'InferenceAPIImplConfig' or 'InferenceEndpointImplConfig'. Got {type(config)}."
diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index 0ef936925..3fe7a30f7 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -7,7 +7,7 @@
 
 from typing import AsyncGenerator, List, Optional
 
-from huggingface_hub import AsyncInferenceClient, HfApi
+from huggingface_hub import AsyncInferenceClient
 
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -52,7 +52,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     completion_request_to_prompt_model_input_info,
 )
 
-from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
+from .config import TGIImplConfig
 
 logger = get_logger(name=__name__, category="inference")
 
@@ -250,33 +250,8 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
 class TGIAdapter(_HfAdapter):
     async def initialize(self, config: TGIImplConfig) -> None:
         logger.info(f"Initializing TGI client with url={config.url}")
-        # unfortunately, the TGI async client does not work well with proxies
-        # so using sync client for now instead
         self.client = AsyncInferenceClient(model=f"{config.url}")
 
         endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
         self.model_id = endpoint_info["model_id"]
-
-
-class InferenceAPIAdapter(_HfAdapter):
-    async def initialize(self, config: InferenceAPIImplConfig) -> None:
-        self.client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
-        endpoint_info = await self.client.get_endpoint_info()
-        self.max_tokens = endpoint_info["max_total_tokens"]
-        self.model_id = endpoint_info["model_id"]
-
-
-class InferenceEndpointAdapter(_HfAdapter):
-    async def initialize(self, config: InferenceEndpointImplConfig) -> None:
-        # Get the inference endpoint details
-        api = HfApi(token=config.api_token.get_secret_value())
-        endpoint = api.get_inference_endpoint(config.endpoint_name)
-
-        # Wait for the endpoint to be ready (if not already)
-        endpoint.wait(timeout=60)
-
-        # Initialize the adapter
-        self.client = endpoint.async_client
-        self.model_id = endpoint.repository
-        self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 66de06c6f..3cc38dc29 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -1157,10 +1157,6 @@ async def convert_chat_completion_request_to_openai_params(
         # Apply additionalProperties: False recursively to all objects
         fmt = _add_additional_properties_recursive(fmt)
 
-        from rich.pretty import pprint
-
-        pprint(fmt)
-
         input_dict["response_format"] = {
             "type": "json_schema",
             "json_schema": {
diff --git a/tests/integration/inference/test_vision_inference.py b/tests/integration/inference/test_vision_inference.py
index 70068aa29..9f6fb0478 100644
--- a/tests/integration/inference/test_vision_inference.py
+++ b/tests/integration/inference/test_vision_inference.py
@@ -87,7 +87,7 @@
     assert any(expected in streamed_content for expected in {"dog", "puppy", "pup"})
 
 
-@pytest.mark.parametrize("type_", ["url"])
+@pytest.mark.parametrize("type_", ["url", "data"])
 def test_image_chat_completion_base64(client_with_models, vision_model_id, base64_image_data, base64_image_url, type_):
     image_spec = {
         "url": {