[Inference] Use huggingface_hub inference client for TGI adapter (#53)

* Use huggingface_hub inference client for TGI inference

* Update the default value for TGI URL

* Use InferenceClient.text_generation for TGI inference

* Fixes post-review and split TGI adapter into local and Inference Endpoints ones

* Update CLI reference and add typing

* Rename TGI Adapter class

* Use HfApi to get the namespace when not provide in the hf endpoint name

* Remove unecessary method argument

* Improve TGI adapter initialization condition

* Move helper into impl file + fix merging conflicts
This commit is contained in:
Celina Hanouti 2024-09-12 18:11:35 +02:00 committed by GitHub
parent 191cd28831
commit 736092f6bc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 171 additions and 72 deletions

View file

@ -4,12 +4,21 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_toolchain.core.datatypes import RemoteProviderConfig
from .config import TGIImplConfig
from .tgi import InferenceEndpointAdapter, TGIAdapter
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
from .tgi import TGIInferenceAdapter
async def get_adapter_impl(config: TGIImplConfig, _deps):
assert isinstance(config, TGIImplConfig), f"Unexpected config type: {type(config)}"
if config.url is not None:
impl = TGIAdapter(config)
elif config.is_inference_endpoint():
impl = InferenceEndpointAdapter(config)
else:
raise ValueError(
"Invalid configuration. Specify either an URL or HF Inference Endpoint details (namespace and endpoint name)."
)
impl = TGIInferenceAdapter(config.url)
await impl.initialize()
return impl