dropped impls for hf serverless and hf endpoint

Hardik Shah 2025-03-28 22:38:16 -07:00
parent 1b15df8d1d
commit 650cbc395d
4 changed files with 6 additions and 44 deletions

@@ -4,23 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Union
-
-from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
+from .config import TGIImplConfig
 
 
-async def get_adapter_impl(
-    config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig],
-    _deps,
-):
-    from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter
+async def get_adapter_impl(config: TGIImplConfig, _deps):
+    from .tgi import TGIAdapter
 
     if isinstance(config, TGIImplConfig):
         impl = TGIAdapter()
-    elif isinstance(config, InferenceAPIImplConfig):
-        impl = InferenceAPIAdapter()
-    elif isinstance(config, InferenceEndpointImplConfig):
-        impl = InferenceEndpointAdapter()
     else:
         raise ValueError(
             f"Invalid configuration. Expected 'TGIAdapter', 'InferenceAPIImplConfig' or 'InferenceEndpointImplConfig'. Got {type(config)}."

@@ -7,7 +7,7 @@
 from typing import AsyncGenerator, List, Optional
 
-from huggingface_hub import AsyncInferenceClient, HfApi
+from huggingface_hub import AsyncInferenceClient
 
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -52,7 +52,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     completion_request_to_prompt_model_input_info,
 )
 
-from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
+from .config import TGIImplConfig
 
 logger = get_logger(name=__name__, category="inference")
@@ -250,33 +250,8 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
 
 class TGIAdapter(_HfAdapter):
     async def initialize(self, config: TGIImplConfig) -> None:
         logger.info(f"Initializing TGI client with url={config.url}")
-        # unfortunately, the TGI async client does not work well with proxies
-        # so using sync client for now instead
         self.client = AsyncInferenceClient(model=f"{config.url}")
         endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
         self.model_id = endpoint_info["model_id"]
-
-
-class InferenceAPIAdapter(_HfAdapter):
-    async def initialize(self, config: InferenceAPIImplConfig) -> None:
-        self.client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
-        endpoint_info = await self.client.get_endpoint_info()
-        self.max_tokens = endpoint_info["max_total_tokens"]
-        self.model_id = endpoint_info["model_id"]
-
-
-class InferenceEndpointAdapter(_HfAdapter):
-    async def initialize(self, config: InferenceEndpointImplConfig) -> None:
-        # Get the inference endpoint details
-        api = HfApi(token=config.api_token.get_secret_value())
-        endpoint = api.get_inference_endpoint(config.endpoint_name)
-
-        # Wait for the endpoint to be ready (if not already)
-        endpoint.wait(timeout=60)
-
-        # Initialize the adapter
-        self.client = endpoint.async_client
-        self.model_id = endpoint.repository
-        self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])
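Dedicated Inference Endpoints expose a TGI-compatible URL, so one possible migration path (a sketch under that assumption, not something this commit ships) is to resolve the endpoint with huggingface_hub yourself and hand its URL to the surviving TGIAdapter:

    from huggingface_hub import HfApi

    def tgi_config_for_endpoint(endpoint_name: str, token: str):
        # Resolve the dedicated endpoint and block until it is up,
        # mirroring what the deleted InferenceEndpointAdapter did.
        api = HfApi(token=token)
        endpoint = api.get_inference_endpoint(endpoint_name)
        endpoint.wait(timeout=60)
        # Reuse the plain TGI adapter against the endpoint's URL;
        # TGIImplConfig accepting a remote URL here is an assumption.
        return TGIImplConfig(url=endpoint.url)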

@@ -1157,10 +1157,6 @@ async def convert_chat_completion_request_to_openai_params(
         # Apply additionalProperties: False recursively to all objects
         fmt = _add_additional_properties_recursive(fmt)
 
-        from rich.pretty import pprint
-
-        pprint(fmt)
-
         input_dict["response_format"] = {
             "type": "json_schema",
             "json_schema": {

@@ -87,7 +87,7 @@ def test_image_chat_completion_streaming(client_with_models, vision_model_id):
     assert any(expected in streamed_content for expected in {"dog", "puppy", "pup"})
 
 
-@pytest.mark.parametrize("type_", ["url"])
+@pytest.mark.parametrize("type_", ["url", "data"])
 def test_image_chat_completion_base64(client_with_models, vision_model_id, base64_image_data, base64_image_url, type_):
     image_spec = {
         "url": {