diff --git a/llama_toolchain/inference/adapters/tgi/__init__.py b/llama_toolchain/inference/adapters/tgi/__init__.py
index 6c9e128bc..743807836 100644
--- a/llama_toolchain/inference/adapters/tgi/__init__.py
+++ b/llama_toolchain/inference/adapters/tgi/__init__.py
@@ -11,7 +11,7 @@ from .tgi import InferenceEndpointAdapter, TGIAdapter
 
 async def get_adapter_impl(config: TGIImplConfig, _deps):
     assert isinstance(config, TGIImplConfig), f"Unexpected config type: {type(config)}"
 
-    if config.is_local_tgi():
+    if config.url is not None:
         impl = TGIAdapter(config)
     elif config.is_inference_endpoint():
         impl = InferenceEndpointAdapter(config)
diff --git a/llama_toolchain/inference/adapters/tgi/config.py b/llama_toolchain/inference/adapters/tgi/config.py
index cb1ec8400..22d4f757b 100644
--- a/llama_toolchain/inference/adapters/tgi/config.py
+++ b/llama_toolchain/inference/adapters/tgi/config.py
@@ -7,6 +7,7 @@
 from typing import Optional
 
 from huggingface_hub import HfApi
+
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
 
@@ -31,6 +32,3 @@ class TGIImplConfig(BaseModel):
 
     def get_namespace(self) -> str:
         return HfApi().whoami()["name"]
-
-    def is_local_tgi(self) -> bool:
-        return self.url is not None and self.url.startswith("http://localhost")
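
The net effect of this change: the removed `is_local_tgi()` helper only matched URLs beginning with `http://localhost`, so a TGI server reachable at any other address (a remote host, a Docker bridge IP, an HTTPS endpoint) could never be routed to `TGIAdapter`. The new check treats any explicitly configured `url` as a TGI server, falling through to the Hugging Face Inference Endpoint path only when no URL is set. A minimal, self-contained sketch of the resulting dispatch is below; the `hf_endpoint_name` field, the `select_adapter` helper, and the returned strings are assumptions for illustration, while `url`, `is_inference_endpoint()`, and the if/elif order come from the diff:

```python
# Sketch of the dispatch after this change -- not the full adapter code.
from typing import Optional

from pydantic import BaseModel, Field


class TGIImplConfig(BaseModel):
    # Any reachable TGI server URL, no longer restricted to http://localhost.
    url: Optional[str] = Field(default=None)
    # Assumed field backing is_inference_endpoint(); illustrative only.
    hf_endpoint_name: Optional[str] = Field(default=None)

    def is_inference_endpoint(self) -> bool:
        return self.hf_endpoint_name is not None


def select_adapter(config: TGIImplConfig) -> str:
    # A configured URL always wins; the HF Inference Endpoint path is
    # the fallback when no URL is set.
    if config.url is not None:
        return "TGIAdapter"
    elif config.is_inference_endpoint():
        return "InferenceEndpointAdapter"
    raise ValueError("Invalid configuration: set either `url` or an inference endpoint")


# A non-localhost URL now selects TGIAdapter, which the old check rejected.
assert select_adapter(TGIImplConfig(url="http://192.168.1.10:8080")) == "TGIAdapter"
assert select_adapter(TGIImplConfig(hf_endpoint_name="my-endpoint")) == "InferenceEndpointAdapter"
```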