dropped impls for hf serverless and hf endpoint

This commit is contained in:
Hardik Shah 2025-03-28 22:38:16 -07:00
parent 1b15df8d1d
commit 650cbc395d
4 changed files with 6 additions and 44 deletions

View file

@ -7,7 +7,7 @@
from typing import AsyncGenerator, List, Optional
from huggingface_hub import AsyncInferenceClient, HfApi
from huggingface_hub import AsyncInferenceClient
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -52,7 +52,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
completion_request_to_prompt_model_input_info,
)
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
from .config import TGIImplConfig
logger = get_logger(name=__name__, category="inference")
@ -250,33 +250,8 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
class TGIAdapter(_HfAdapter):
async def initialize(self, config: TGIImplConfig) -> None:
logger.info(f"Initializing TGI client with url={config.url}")
# unfortunately, the TGI async client does not work well with proxies
# so using sync client for now instead
self.client = AsyncInferenceClient(model=f"{config.url}")
endpoint_info = await self.client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
class InferenceAPIAdapter(_HfAdapter):
async def initialize(self, config: InferenceAPIImplConfig) -> None:
self.client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
endpoint_info = await self.client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
class InferenceEndpointAdapter(_HfAdapter):
async def initialize(self, config: InferenceEndpointImplConfig) -> None:
# Get the inference endpoint details
api = HfApi(token=config.api_token.get_secret_value())
endpoint = api.get_inference_endpoint(config.endpoint_name)
# Wait for the endpoint to be ready (if not already)
endpoint.wait(timeout=60)
# Initialize the adapter
self.client = endpoint.async_client
self.model_id = endpoint.repository
self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])