From c8808b4700a256477512e14802747ac5a34cd987 Mon Sep 17 00:00:00 2001
From: Celina Hanouti
Date: Thu, 12 Sep 2024 15:55:42 +0200
Subject: [PATCH] Move helper into impl file + fix merging conflicts

---
 llama_toolchain/core/distribution_registry.py |  8 +++---
 .../inference/adapters/tgi/config.py          |  5 ----
 llama_toolchain/inference/adapters/tgi/tgi.py | 26 +++++++++----------
 3 files changed, 18 insertions(+), 21 deletions(-)

diff --git a/llama_toolchain/core/distribution_registry.py b/llama_toolchain/core/distribution_registry.py
index 855fc6300..ca13dffbc 100644
--- a/llama_toolchain/core/distribution_registry.py
+++ b/llama_toolchain/core/distribution_registry.py
@@ -59,10 +59,10 @@ def available_distribution_specs() -> List[DistributionSpec]:
             },
         ),
         DistributionSpec(
-            distribution_id="local-plus-tgi-inference",
+            distribution_type="local-plus-tgi-inference",
             description="Use TGI for running LLM inference",
             providers={
-                Api.inference: remote_provider_id("tgi"),
+                Api.inference: remote_provider_type("tgi"),
                 Api.safety: "meta-reference",
                 Api.agentic_system: "meta-reference",
                 Api.memory: "meta-reference-faiss",
@@ -72,7 +72,9 @@ def available_distribution_specs() -> List[DistributionSpec]:
 
 
 @lru_cache()
-def resolve_distribution_spec(distribution_type: str) -> Optional[DistributionSpec]:
+def resolve_distribution_spec(
+    distribution_type: str,
+) -> Optional[DistributionSpec]:
     for spec in available_distribution_specs():
         if spec.distribution_type == distribution_type:
             return spec
diff --git a/llama_toolchain/inference/adapters/tgi/config.py b/llama_toolchain/inference/adapters/tgi/config.py
index 22d4f757b..a0135dfdd 100644
--- a/llama_toolchain/inference/adapters/tgi/config.py
+++ b/llama_toolchain/inference/adapters/tgi/config.py
@@ -6,8 +6,6 @@
 
 from typing import Optional
 
-from huggingface_hub import HfApi
-
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
 
@@ -29,6 +27,3 @@ class TGIImplConfig(BaseModel):
 
     def is_inference_endpoint(self) -> bool:
         return self.hf_endpoint_name is not None
-
-    def get_namespace(self) -> str:
-        return HfApi().whoami()["name"]
diff --git a/llama_toolchain/inference/adapters/tgi/tgi.py b/llama_toolchain/inference/adapters/tgi/tgi.py
index 73a72e390..bb7b99d02 100644
--- a/llama_toolchain/inference/adapters/tgi/tgi.py
+++ b/llama_toolchain/inference/adapters/tgi/tgi.py
@@ -5,20 +5,15 @@
 # the root directory of this source tree.
 
-from typing import AsyncGenerator
+from typing import Any, AsyncGenerator, Dict
 
 import requests
-from huggingface_hub import InferenceClient
+
+from huggingface_hub import HfApi, InferenceClient
 
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
-
-from llama_toolchain.inference.api import *
-from llama_toolchain.inference.api.api import (  # noqa: F403
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseStreamChunk,
-)
+from llama_toolchain.inference.api import *  # noqa: F403
 from llama_toolchain.inference.prepare_messages import prepare_messages
 
 from .config import TGIImplConfig
@@ -31,7 +26,6 @@ HF_SUPPORTED_MODELS = {
 
 
 class TGIAdapter(Inference):
-
     def __init__(self, config: TGIImplConfig) -> None:
         self.config = config
         self.tokenizer = Tokenizer.get_instance()
@@ -42,7 +36,10 @@ class TGIAdapter(Inference):
         return InferenceClient(model=self.config.url, token=self.config.api_token)
 
     def _get_endpoint_info(self) -> Dict[str, Any]:
-        return {**self.client.get_endpoint_info(), "inference_url": self.config.url}
+        return {
+            **self.client.get_endpoint_info(),
+            "inference_url": self.config.url,
+        }
 
     async def initialize(self) -> None:
         try:
@@ -68,7 +65,7 @@
             import traceback
 
             traceback.print_exc()
-            raise RuntimeError(f"Error initializing TGIAdapter: {e}")
+            raise RuntimeError(f"Error initializing TGIAdapter: {e}") from e
 
     async def shutdown(self) -> None:
         pass
@@ -244,12 +241,15 @@ class InferenceEndpointAdapter(TGIAdapter):
                 "or 'endpoint_name'"
             )
         if "/" not in hf_endpoint_name:
-            hf_namespace: str = self.config.get_namespace()
+            hf_namespace: str = self.get_namespace()
             endpoint_path = f"{hf_namespace}/{hf_endpoint_name}"
         else:
             endpoint_path = hf_endpoint_name
         return f"https://api.endpoints.huggingface.cloud/v2/endpoint/{endpoint_path}"
 
+    def get_namespace(self) -> str:
+        return HfApi().whoami()["name"]
+
     @property
     def client(self) -> InferenceClient:
         return InferenceClient(model=self.inference_url, token=self.config.api_token)