Use HfApi to get the namespace when not provide in the hf endpoint name

This commit is contained in:
Celina Hanouti 2024-09-09 18:59:10 +02:00
parent 3d660ad938
commit fff1b6d6bf
3 changed files with 20 additions and 8 deletions

View file

@ -235,7 +235,19 @@ class TGIAdapter(Inference):
class InferenceEndpointAdapter(TGIAdapter):
def __init__(self, config: TGIImplConfig) -> None:
super().__init__(config)
self.config.url = f"https://api.endpoints.huggingface.cloud/v2/endpoint/{config.hf_namespace}/{config.hf_endpoint_name}"
self.config.url = self._construct_endpoint_url(config.hf_endpoint_name)
def _construct_endpoint_url(self, hf_endpoint_name: str) -> str:
assert hf_endpoint_name.count("/") <= 1, (
"Endpoint name must be in the format of 'namespace/endpoint_name' "
"or 'endpoint_name'"
)
if "/" not in hf_endpoint_name:
hf_namespace: str = self.config.get_namespace()
endpoint_path = f"{hf_namespace}/{hf_endpoint_name}"
else:
endpoint_path = hf_endpoint_name
return f"https://api.endpoints.huggingface.cloud/v2/endpoint/{endpoint_path}"
@property
def client(self) -> InferenceClient: