forked from phoenix-oss/llama-stack-mirror
* Use the huggingface_hub inference client for TGI inference
* Update the default value for the TGI URL
* Use InferenceClient.text_generation for TGI inference (sketched below)
* Fix post-review issues and split the TGI adapter into local and Inference Endpoints variants
* Update the CLI reference and add typing
* Rename the TGI Adapter class
* Use HfApi to get the namespace when it is not provided in the HF endpoint name
* Remove an unnecessary method argument
* Improve the TGI adapter initialization condition
* Move a helper into the impl file + fix merge conflicts
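The bullet about InferenceClient.text_generation refers to the huggingface_hub client API. Below is a minimal, hedged sketch of how a TGI backend might be called through that API; it is not the adapter's actual code, and the URL, prompt, generation parameters, and the HfApi.whoami()-based namespace fallback are illustrative assumptions drawn only from the bullets above.

# Hedged sketch, not the adapter's real implementation: shows the
# huggingface_hub calls referenced in the change list above.
from huggingface_hub import HfApi, InferenceClient

# Local TGI server: point the client at the server URL (placeholder value).
client = InferenceClient(model="http://localhost:8080")
output = client.text_generation(
    "Write a haiku about llamas.",  # placeholder prompt
    max_new_tokens=64,
    temperature=0.7,
)
print(output)

# "Use HfApi to get the namespace when not provided": one plausible way to
# resolve a bare Inference Endpoint name is to fall back to the token owner's
# account name (an assumption, not confirmed by the file below).
def resolve_endpoint_name(endpoint_name: str, token: str | None = None) -> str:
    if "/" in endpoint_name:
        return endpoint_name
    namespace = HfApi(token=token).whoami()["name"]
    return f"{namespace}/{endpoint_name}"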
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import TGIImplConfig
from .tgi import InferenceEndpointAdapter, TGIAdapter


async def get_adapter_impl(config: TGIImplConfig, _deps):
    assert isinstance(config, TGIImplConfig), f"Unexpected config type: {type(config)}"

    # A plain URL means a locally reachable TGI server; otherwise fall back to
    # a Hugging Face Inference Endpoint described by the config.
    if config.url is not None:
        impl = TGIAdapter(config)
    elif config.is_inference_endpoint():
        impl = InferenceEndpointAdapter(config)
    else:
        raise ValueError(
            "Invalid configuration. Specify either an URL or HF Inference Endpoint details (namespace and endpoint name)."
        )

    await impl.initialize()
    return impl
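For orientation, a hedged usage sketch of the dispatcher above. Only config.url and config.is_inference_endpoint() are visible in this file, so the TGIImplConfig(url=...) constructor call, the module path in the imports, and passing None for _deps are assumptions made for illustration.

# Hedged usage sketch for get_adapter_impl. The import paths and the
# TGIImplConfig(url=...) constructor are assumptions; only `config.url` and
# `config.is_inference_endpoint()` are visible in the file above.
import asyncio

from tgi import get_adapter_impl        # placeholder import path
from tgi.config import TGIImplConfig    # placeholder import path


async def main():
    config = TGIImplConfig(url="http://localhost:8080")  # placeholder URL
    impl = await get_adapter_impl(config, None)  # _deps unused in this sketch
    # impl is now an initialized TGIAdapter, ready to serve inference requests.


if __name__ == "__main__":
    asyncio.run(main())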