mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 18:22:41 +00:00
Dropped implementations for the HF serverless (Inference API) and HF Inference Endpoint adapters
This commit is contained in:
parent
1b15df8d1d
commit
650cbc395d
4 changed files with 6 additions and 44 deletions
|
@ -4,23 +4,14 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Union
|
from .config import TGIImplConfig
|
||||||
|
|
||||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(
|
async def get_adapter_impl(config: TGIImplConfig, _deps):
|
||||||
config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig],
|
from .tgi import TGIAdapter
|
||||||
_deps,
|
|
||||||
):
|
|
||||||
from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter
|
|
||||||
|
|
||||||
if isinstance(config, TGIImplConfig):
|
if isinstance(config, TGIImplConfig):
|
||||||
impl = TGIAdapter()
|
impl = TGIAdapter()
|
||||||
elif isinstance(config, InferenceAPIImplConfig):
|
|
||||||
impl = InferenceAPIAdapter()
|
|
||||||
elif isinstance(config, InferenceEndpointImplConfig):
|
|
||||||
impl = InferenceEndpointAdapter()
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid configuration. Expected 'TGIAdapter', 'InferenceAPIImplConfig' or 'InferenceEndpointImplConfig'. Got {type(config)}."
|
f"Invalid configuration. Expected 'TGIAdapter', 'InferenceAPIImplConfig' or 'InferenceEndpointImplConfig'. Got {type(config)}."
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
|
|
||||||
from typing import AsyncGenerator, List, Optional
|
from typing import AsyncGenerator, List, Optional
|
||||||
|
|
||||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
from huggingface_hub import AsyncInferenceClient
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import (
|
||||||
InterleavedContent,
|
InterleavedContent,
|
||||||
|
@ -52,7 +52,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
completion_request_to_prompt_model_input_info,
|
completion_request_to_prompt_model_input_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
from .config import TGIImplConfig
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="inference")
|
logger = get_logger(name=__name__, category="inference")
|
||||||
|
|
||||||
|
@ -250,33 +250,8 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
||||||
class TGIAdapter(_HfAdapter):
|
class TGIAdapter(_HfAdapter):
|
||||||
async def initialize(self, config: TGIImplConfig) -> None:
|
async def initialize(self, config: TGIImplConfig) -> None:
|
||||||
logger.info(f"Initializing TGI client with url={config.url}")
|
logger.info(f"Initializing TGI client with url={config.url}")
|
||||||
# unfortunately, the TGI async client does not work well with proxies
|
|
||||||
# so using sync client for now instead
|
|
||||||
self.client = AsyncInferenceClient(model=f"{config.url}")
|
self.client = AsyncInferenceClient(model=f"{config.url}")
|
||||||
|
|
||||||
endpoint_info = await self.client.get_endpoint_info()
|
endpoint_info = await self.client.get_endpoint_info()
|
||||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||||
self.model_id = endpoint_info["model_id"]
|
self.model_id = endpoint_info["model_id"]
|
||||||
|
|
||||||
|
|
||||||
class InferenceAPIAdapter(_HfAdapter):
|
|
||||||
async def initialize(self, config: InferenceAPIImplConfig) -> None:
|
|
||||||
self.client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
|
|
||||||
endpoint_info = await self.client.get_endpoint_info()
|
|
||||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
|
||||||
self.model_id = endpoint_info["model_id"]
|
|
||||||
|
|
||||||
|
|
||||||
class InferenceEndpointAdapter(_HfAdapter):
|
|
||||||
async def initialize(self, config: InferenceEndpointImplConfig) -> None:
|
|
||||||
# Get the inference endpoint details
|
|
||||||
api = HfApi(token=config.api_token.get_secret_value())
|
|
||||||
endpoint = api.get_inference_endpoint(config.endpoint_name)
|
|
||||||
|
|
||||||
# Wait for the endpoint to be ready (if not already)
|
|
||||||
endpoint.wait(timeout=60)
|
|
||||||
|
|
||||||
# Initialize the adapter
|
|
||||||
self.client = endpoint.async_client
|
|
||||||
self.model_id = endpoint.repository
|
|
||||||
self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])
|
|
||||||
|
|
|
@ -1157,10 +1157,6 @@ async def convert_chat_completion_request_to_openai_params(
|
||||||
# Apply additionalProperties: False recursively to all objects
|
# Apply additionalProperties: False recursively to all objects
|
||||||
fmt = _add_additional_properties_recursive(fmt)
|
fmt = _add_additional_properties_recursive(fmt)
|
||||||
|
|
||||||
from rich.pretty import pprint
|
|
||||||
|
|
||||||
pprint(fmt)
|
|
||||||
|
|
||||||
input_dict["response_format"] = {
|
input_dict["response_format"] = {
|
||||||
"type": "json_schema",
|
"type": "json_schema",
|
||||||
"json_schema": {
|
"json_schema": {
|
||||||
|
|
|
@ -87,7 +87,7 @@ def test_image_chat_completion_streaming(client_with_models, vision_model_id):
|
||||||
assert any(expected in streamed_content for expected in {"dog", "puppy", "pup"})
|
assert any(expected in streamed_content for expected in {"dog", "puppy", "pup"})
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("type_", ["url"])
|
@pytest.mark.parametrize("type_", ["url", "data"])
|
||||||
def test_image_chat_completion_base64(client_with_models, vision_model_id, base64_image_data, base64_image_url, type_):
|
def test_image_chat_completion_base64(client_with_models, vision_model_id, base64_image_data, base64_image_url, type_):
|
||||||
image_spec = {
|
image_spec = {
|
||||||
"url": {
|
"url": {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue