Move helper into impl file + fix merge conflicts

This commit is contained in:
Celina Hanouti 2024-09-12 15:55:42 +02:00
parent 04f0b8fe11
commit c8808b4700
3 changed files with 18 additions and 21 deletions

View file

@ -5,20 +5,15 @@
# the root directory of this source tree.
from typing import AsyncGenerator
from typing import Any, AsyncGenerator, Dict
import requests
from huggingface_hub import InferenceClient
from huggingface_hub import HfApi, InferenceClient
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_toolchain.inference.api import *
from llama_toolchain.inference.api.api import ( # noqa: F403
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
)
from llama_toolchain.inference.api import * # noqa: F403
from llama_toolchain.inference.prepare_messages import prepare_messages
from .config import TGIImplConfig
@ -31,7 +26,6 @@ HF_SUPPORTED_MODELS = {
class TGIAdapter(Inference):
def __init__(self, config: TGIImplConfig) -> None:
self.config = config
self.tokenizer = Tokenizer.get_instance()
@ -42,7 +36,10 @@ class TGIAdapter(Inference):
return InferenceClient(model=self.config.url, token=self.config.api_token)
def _get_endpoint_info(self) -> Dict[str, Any]:
return {**self.client.get_endpoint_info(), "inference_url": self.config.url}
return {
**self.client.get_endpoint_info(),
"inference_url": self.config.url,
}
async def initialize(self) -> None:
try:
@ -68,7 +65,7 @@ class TGIAdapter(Inference):
import traceback
traceback.print_exc()
raise RuntimeError(f"Error initializing TGIAdapter: {e}")
raise RuntimeError(f"Error initializing TGIAdapter: {e}") from e
async def shutdown(self) -> None:
pass
@ -244,12 +241,15 @@ class InferenceEndpointAdapter(TGIAdapter):
"or 'endpoint_name'"
)
if "/" not in hf_endpoint_name:
hf_namespace: str = self.config.get_namespace()
hf_namespace: str = self.get_namespace()
endpoint_path = f"{hf_namespace}/{hf_endpoint_name}"
else:
endpoint_path = hf_endpoint_name
return f"https://api.endpoints.huggingface.cloud/v2/endpoint/{endpoint_path}"
def get_namespace(self) -> str:
return HfApi().whoami()["name"]
@property
def client(self) -> InferenceClient:
return InferenceClient(model=self.inference_url, token=self.config.api_token)