Move helper into impl file + fix merge conflicts

Celina Hanouti 2024-09-12 15:55:42 +02:00
parent 04f0b8fe11
commit c8808b4700
3 changed files with 18 additions and 21 deletions


@@ -59,10 +59,10 @@ def available_distribution_specs() -> List[DistributionSpec]:
             },
         ),
         DistributionSpec(
-            distribution_id="local-plus-tgi-inference",
+            distribution_type="local-plus-tgi-inference",
             description="Use TGI for running LLM inference",
             providers={
-                Api.inference: remote_provider_id("tgi"),
+                Api.inference: remote_provider_type("tgi"),
                 Api.safety: "meta-reference",
                 Api.agentic_system: "meta-reference",
                 Api.memory: "meta-reference-faiss",
@@ -72,7 +72,9 @@ def available_distribution_specs() -> List[DistributionSpec]:
 
 
 @lru_cache()
-def resolve_distribution_spec(distribution_type: str) -> Optional[DistributionSpec]:
+def resolve_distribution_spec(
+    distribution_type: str,
+) -> Optional[DistributionSpec]:
     for spec in available_distribution_specs():
         if spec.distribution_type == distribution_type:
             return spec

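For context, a minimal standalone sketch of the renamed lookup above. DistributionSpec is simplified to a dataclass stand-in here (the real model lives in llama_toolchain), and the spec list is trimmed to the fields this diff touches:

from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional


@dataclass(frozen=True)
class DistributionSpec:  # simplified stand-in for the real model
    distribution_type: str
    description: str = ""


def available_distribution_specs() -> List[DistributionSpec]:
    return [
        DistributionSpec(
            distribution_type="local-plus-tgi-inference",
            description="Use TGI for running LLM inference",
        ),
    ]


@lru_cache()
def resolve_distribution_spec(
    distribution_type: str,
) -> Optional[DistributionSpec]:
    # A linear scan is fine: the spec list is tiny and the result is cached.
    for spec in available_distribution_specs():
        if spec.distribution_type == distribution_type:
            return spec
    return None


print(resolve_distribution_spec("local-plus-tgi-inference"))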

@@ -6,8 +6,6 @@
 
 from typing import Optional
 
-from huggingface_hub import HfApi
-
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
 
@@ -29,6 +27,3 @@ class TGIImplConfig(BaseModel):
 
     def is_inference_endpoint(self) -> bool:
         return self.hf_endpoint_name is not None
-
-    def get_namespace(self) -> str:
-        return HfApi().whoami()["name"]

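After this change TGIImplConfig no longer needs huggingface_hub at all. A trimmed sketch of the resulting config (hf_endpoint_name comes from this diff; url and api_token are taken from the adapter diff below, and the defaults here are assumptions):

from typing import Optional

from pydantic import BaseModel


class TGIImplConfig(BaseModel):  # trimmed sketch, not the full config
    url: Optional[str] = None
    api_token: Optional[str] = None
    hf_endpoint_name: Optional[str] = None

    def is_inference_endpoint(self) -> bool:
        # Set hf_endpoint_name to target a managed HF Inference Endpoint
        # instead of a self-hosted TGI server at `url`.
        return self.hf_endpoint_name is not None


print(TGIImplConfig(hf_endpoint_name="my-endpoint").is_inference_endpoint())  # True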

@@ -5,20 +5,15 @@
 # the root directory of this source tree.
 
-from typing import AsyncGenerator
+from typing import Any, AsyncGenerator, Dict
 
 import requests
 
-from huggingface_hub import InferenceClient
+from huggingface_hub import HfApi, InferenceClient
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
-
-from llama_toolchain.inference.api import *
-from llama_toolchain.inference.api.api import (  # noqa: F403
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseStreamChunk,
-)
+
+from llama_toolchain.inference.api import *  # noqa: F403
 from llama_toolchain.inference.prepare_messages import prepare_messages
 
 from .config import TGIImplConfig
@@ -31,7 +26,6 @@ HF_SUPPORTED_MODELS = {
 
 
 class TGIAdapter(Inference):
-
     def __init__(self, config: TGIImplConfig) -> None:
         self.config = config
         self.tokenizer = Tokenizer.get_instance()
@@ -42,7 +36,10 @@ class TGIAdapter(Inference):
         return InferenceClient(model=self.config.url, token=self.config.api_token)
 
     def _get_endpoint_info(self) -> Dict[str, Any]:
-        return {**self.client.get_endpoint_info(), "inference_url": self.config.url}
+        return {
+            **self.client.get_endpoint_info(),
+            "inference_url": self.config.url,
+        }
 
     async def initialize(self) -> None:
         try:
@@ -68,7 +65,7 @@ class TGIAdapter(Inference):
             import traceback
 
             traceback.print_exc()
-            raise RuntimeError(f"Error initializing TGIAdapter: {e}")
+            raise RuntimeError(f"Error initializing TGIAdapter: {e}") from e
 
     async def shutdown(self) -> None:
         pass
@@ -244,12 +241,15 @@ class InferenceEndpointAdapter(TGIAdapter):
                 "or 'endpoint_name'"
             )
         if "/" not in hf_endpoint_name:
-            hf_namespace: str = self.config.get_namespace()
+            hf_namespace: str = self.get_namespace()
             endpoint_path = f"{hf_namespace}/{hf_endpoint_name}"
         else:
             endpoint_path = hf_endpoint_name
         return f"https://api.endpoints.huggingface.cloud/v2/endpoint/{endpoint_path}"
 
+    def get_namespace(self) -> str:
+        return HfApi().whoami()["name"]
+
     @property
     def client(self) -> InferenceClient:
         return InferenceClient(model=self.inference_url, token=self.config.api_token)
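The namespace helper now lives on the adapter, next to its only caller. A condensed sketch of that URL resolution (the class and endpoint_url method names here are hypothetical; in the real adapter this logic sits inside the hunk above):

from huggingface_hub import HfApi


class EndpointPathSketch:  # condensed from InferenceEndpointAdapter above
    def __init__(self, hf_endpoint_name: str) -> None:
        self.hf_endpoint_name = hf_endpoint_name

    def get_namespace(self) -> str:
        # whoami() calls the Hub API and needs a valid HF token configured.
        return HfApi().whoami()["name"]

    def endpoint_url(self) -> str:
        # Bare names like "my-endpoint" get qualified with the caller's
        # namespace; "org/my-endpoint" is already fully qualified.
        if "/" not in self.hf_endpoint_name:
            endpoint_path = f"{self.get_namespace()}/{self.hf_endpoint_name}"
        else:
            endpoint_path = self.hf_endpoint_name
        return f"https://api.endpoints.huggingface.cloud/v2/endpoint/{endpoint_path}"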