Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 07:14:20 +00:00)

Commit c8808b4700 (parent 04f0b8fe11): Move helper into impl file + fix merge conflicts

3 changed files with 18 additions and 21 deletions
@@ -59,10 +59,10 @@ def available_distribution_specs() -> List[DistributionSpec]:
             },
         ),
         DistributionSpec(
-            distribution_id="local-plus-tgi-inference",
+            distribution_type="local-plus-tgi-inference",
             description="Use TGI for running LLM inference",
             providers={
-                Api.inference: remote_provider_id("tgi"),
+                Api.inference: remote_provider_type("tgi"),
                 Api.safety: "meta-reference",
                 Api.agentic_system: "meta-reference",
                 Api.memory: "meta-reference-faiss",
@@ -72,7 +72,9 @@ def available_distribution_specs() -> List[DistributionSpec]:


 @lru_cache()
-def resolve_distribution_spec(distribution_type: str) -> Optional[DistributionSpec]:
+def resolve_distribution_spec(
+    distribution_type: str,
+) -> Optional[DistributionSpec]:
     for spec in available_distribution_specs():
         if spec.distribution_type == distribution_type:
             return spec
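Below is a minimal usage sketch, not part of the commit, showing how the renamed field and the reformatted resolver fit together; the import paths are assumptions based on the hunk context.

# Illustrative only: exercises the post-rename API surface shown in the hunks above.
# Import paths are assumed; adjust to wherever these symbols live in llama_toolchain.
from llama_toolchain.distribution.datatypes import Api, DistributionSpec  # assumed path
from llama_toolchain.distribution.registry import resolve_distribution_spec  # assumed path

spec = resolve_distribution_spec("local-plus-tgi-inference")
if spec is not None:
    assert spec.distribution_type == "local-plus-tgi-inference"
    print(spec.description)               # "Use TGI for running LLM inference"
    print(spec.providers[Api.inference])  # the remote "tgi" provider type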
@@ -6,8 +6,6 @@

 from typing import Optional

-from huggingface_hub import HfApi
-
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field

@@ -29,6 +27,3 @@ class TGIImplConfig(BaseModel):

     def is_inference_endpoint(self) -> bool:
         return self.hf_endpoint_name is not None
-
-    def get_namespace(self) -> str:
-        return HfApi().whoami()["name"]
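For orientation, a hedged sketch of the config shape implied by these hunks: url and api_token drive a plain TGI server, while hf_endpoint_name flags a hosted Inference Endpoint; after this commit the namespace lookup no longer lives on the config. Field names come from the diff, but their types and defaults are assumptions.

# Sketch only: mirrors the fields referenced in the diff (url, api_token,
# hf_endpoint_name); optionality and defaults are assumptions.
from typing import Optional
from pydantic import BaseModel

class TGIImplConfigSketch(BaseModel):
    url: Optional[str] = None               # local/self-hosted TGI server URL
    api_token: Optional[str] = None         # HF token for protected endpoints
    hf_endpoint_name: Optional[str] = None  # "<namespace>/<endpoint>" or bare name

    def is_inference_endpoint(self) -> bool:
        # Same predicate as the real config; get_namespace() now lives in tgi.py.
        return self.hf_endpoint_name is not None

local = TGIImplConfigSketch(url="http://localhost:8080")
hosted = TGIImplConfigSketch(hf_endpoint_name="my-org/my-tgi-endpoint", api_token="hf_xxx")
assert not local.is_inference_endpoint()
assert hosted.is_inference_endpoint()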
@@ -5,20 +5,15 @@
 # the root directory of this source tree.

-
-from typing import AsyncGenerator
+from typing import Any, AsyncGenerator, Dict

 import requests
-from huggingface_hub import InferenceClient
+
+from huggingface_hub import HfApi, InferenceClient
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer

-from llama_toolchain.inference.api import *
-from llama_toolchain.inference.api.api import (  # noqa: F403
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseStreamChunk,
-)
+from llama_toolchain.inference.api import *  # noqa: F403
 from llama_toolchain.inference.prepare_messages import prepare_messages

 from .config import TGIImplConfig
@@ -31,7 +26,6 @@ HF_SUPPORTED_MODELS = {


 class TGIAdapter(Inference):
-
     def __init__(self, config: TGIImplConfig) -> None:
         self.config = config
         self.tokenizer = Tokenizer.get_instance()
@@ -42,7 +36,10 @@ class TGIAdapter(Inference):
         return InferenceClient(model=self.config.url, token=self.config.api_token)

     def _get_endpoint_info(self) -> Dict[str, Any]:
-        return {**self.client.get_endpoint_info(), "inference_url": self.config.url}
+        return {
+            **self.client.get_endpoint_info(),
+            "inference_url": self.config.url,
+        }

     async def initialize(self) -> None:
         try:
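A compact sketch of the pattern this hunk reformats, written as a standalone class rather than the adapter itself (the wiring and class name are assumptions): the client property builds an InferenceClient from the configured URL and token, and the helper merges the server's reported info with that URL.

# Sketch under assumptions: a standalone stand-in for TGIAdapter's wiring.
from typing import Any, Dict, Optional
from huggingface_hub import InferenceClient

class _TGIEndpointInfoSketch:
    def __init__(self, url: str, api_token: Optional[str] = None) -> None:
        self.url = url
        self.api_token = api_token

    @property
    def client(self) -> InferenceClient:
        # Re-created on each access, matching the property shown in the diff.
        return InferenceClient(model=self.url, token=self.api_token)

    def get_endpoint_info(self) -> Dict[str, Any]:
        # Merge the TGI server's info payload with the configured URL,
        # as the reformatted multi-line return above does.
        return {
            **self.client.get_endpoint_info(),
            "inference_url": self.url,
        }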
@@ -68,7 +65,7 @@ class TGIAdapter(Inference):
             import traceback

             traceback.print_exc()
-            raise RuntimeError(f"Error initializing TGIAdapter: {e}")
+            raise RuntimeError(f"Error initializing TGIAdapter: {e}") from e

     async def shutdown(self) -> None:
         pass
@@ -244,12 +241,15 @@ class InferenceEndpointAdapter(TGIAdapter):
                 "or 'endpoint_name'"
             )
         if "/" not in hf_endpoint_name:
-            hf_namespace: str = self.config.get_namespace()
+            hf_namespace: str = self.get_namespace()
             endpoint_path = f"{hf_namespace}/{hf_endpoint_name}"
         else:
             endpoint_path = hf_endpoint_name
         return f"https://api.endpoints.huggingface.cloud/v2/endpoint/{endpoint_path}"

+    def get_namespace(self) -> str:
+        return HfApi().whoami()["name"]
+
     @property
     def client(self) -> InferenceClient:
         return InferenceClient(model=self.inference_url, token=self.config.api_token)
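The moved helper and the URL construction above reduce to the following standalone sketch; the function name is hypothetical, the logic mirrors the hunk, and whoami() needs a locally configured HF token.

# Standalone sketch of the moved get_namespace() helper and the endpoint URL it feeds.
from huggingface_hub import HfApi

def build_endpoint_api_url(hf_endpoint_name: str) -> str:
    if "/" not in hf_endpoint_name:
        # Bare endpoint name: prefix it with the caller's HF namespace,
        # which whoami() reports under the "name" key.
        hf_namespace = HfApi().whoami()["name"]
        endpoint_path = f"{hf_namespace}/{hf_endpoint_name}"
    else:
        endpoint_path = hf_endpoint_name
    return f"https://api.endpoints.huggingface.cloud/v2/endpoint/{endpoint_path}"

print(build_endpoint_api_url("my-org/my-tgi-endpoint"))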