Mirror of https://github.com/meta-llama/llama-stack.git

commit c8808b4700 (parent 04f0b8fe11)

Move helper into impl file + fix merging conflicts

3 changed files with 18 additions and 21 deletions
File 1 of 3 — distribution specs (available_distribution_specs / resolve_distribution_spec):

@@ -59,10 +59,10 @@ def available_distribution_specs() -> List[DistributionSpec]:
             },
         ),
         DistributionSpec(
-            distribution_id="local-plus-tgi-inference",
+            distribution_type="local-plus-tgi-inference",
             description="Use TGI for running LLM inference",
             providers={
-                Api.inference: remote_provider_id("tgi"),
+                Api.inference: remote_provider_type("tgi"),
                 Api.safety: "meta-reference",
                 Api.agentic_system: "meta-reference",
                 Api.memory: "meta-reference-faiss",
@@ -72,7 +72,9 @@ def available_distribution_specs() -> List[DistributionSpec]:


 @lru_cache()
-def resolve_distribution_spec(distribution_type: str) -> Optional[DistributionSpec]:
+def resolve_distribution_spec(
+    distribution_type: str,
+) -> Optional[DistributionSpec]:
     for spec in available_distribution_specs():
         if spec.distribution_type == distribution_type:
             return spec
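For reference, a minimal usage sketch of the reshaped helper. The call site below is hypothetical and assumes `resolve_distribution_spec` is in scope; only the function name, parameter name, and the "local-plus-tgi-inference" distribution_type come from the hunks above.

```python
# Hypothetical call site; the lookup is keyed by `distribution_type` after this commit.
spec = resolve_distribution_spec(distribution_type="local-plus-tgi-inference")
if spec is not None:
    print(spec.description)  # "Use TGI for running LLM inference"
```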
File 2 of 3 — TGI adapter config (TGIImplConfig):

@@ -6,8 +6,6 @@
 from typing import Optional

-from huggingface_hub import HfApi
-
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field

@@ -29,6 +27,3 @@ class TGIImplConfig(BaseModel):

     def is_inference_endpoint(self) -> bool:
         return self.hf_endpoint_name is not None
-
-    def get_namespace(self) -> str:
-        return HfApi().whoami()["name"]
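After these two hunks the config class no longer touches `huggingface_hub` at all. A hedged sketch of the resulting shape: only `url`, `api_token`, and `hf_endpoint_name` are attested by this diff; their types, defaults, and any `Field`/decorator details are assumptions and are omitted or simplified here.

```python
# Sketch of TGIImplConfig after the helper is moved out; field types are assumed.
from typing import Optional

from pydantic import BaseModel


class TGIImplConfig(BaseModel):
    url: Optional[str] = None
    api_token: Optional[str] = None
    hf_endpoint_name: Optional[str] = None

    def is_inference_endpoint(self) -> bool:
        return self.hf_endpoint_name is not None
```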
File 3 of 3 — TGI adapter implementation (TGIAdapter and InferenceEndpointAdapter):

@@ -5,20 +5,15 @@
 # the root directory of this source tree.


-from typing import AsyncGenerator
+from typing import Any, AsyncGenerator, Dict

 import requests
-from huggingface_hub import InferenceClient
+from huggingface_hub import HfApi, InferenceClient
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_toolchain.inference.api import *
-from llama_toolchain.inference.api.api import (  # noqa: F403
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseStreamChunk,
-)
+from llama_toolchain.inference.api import *  # noqa: F403
 from llama_toolchain.inference.prepare_messages import prepare_messages

 from .config import TGIImplConfig
@@ -31,7 +26,6 @@ HF_SUPPORTED_MODELS = {

-
 class TGIAdapter(Inference):

     def __init__(self, config: TGIImplConfig) -> None:
         self.config = config
         self.tokenizer = Tokenizer.get_instance()
@@ -42,7 +36,10 @@ class TGIAdapter(Inference):
         return InferenceClient(model=self.config.url, token=self.config.api_token)

     def _get_endpoint_info(self) -> Dict[str, Any]:
-        return {**self.client.get_endpoint_info(), "inference_url": self.config.url}
+        return {
+            **self.client.get_endpoint_info(),
+            "inference_url": self.config.url,
+        }

     async def initialize(self) -> None:
         try:
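The reformatted return value relies on standard dict unpacking, where a later key overrides an earlier one. A plain-Python illustration of the `{**mapping, "key": value}` pattern, with made-up sample data unrelated to TGI:

```python
# Dict merge order: keys listed later win over keys from the unpacked mapping.
endpoint_info = {"model_id": "llama-3-8b", "status": "running"}
merged = {**endpoint_info, "inference_url": "http://localhost:8080"}
assert merged == {
    "model_id": "llama-3-8b",
    "status": "running",
    "inference_url": "http://localhost:8080",
}
```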
@@ -68,7 +65,7 @@ class TGIAdapter(Inference):
             import traceback

             traceback.print_exc()
-            raise RuntimeError(f"Error initializing TGIAdapter: {e}")
+            raise RuntimeError(f"Error initializing TGIAdapter: {e}") from e

     async def shutdown(self) -> None:
         pass
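The only functional change in this hunk is the added `from e`, which records the original error as `__cause__` on the re-raised exception instead of leaving an implicit "during handling of the above exception" trace. A standalone illustration, unrelated to the TGI code itself:

```python
# Explicit exception chaining with `raise ... from e` preserves the root cause.
def connect() -> None:
    raise ConnectionError("backend unreachable")

try:
    try:
        connect()
    except ConnectionError as e:
        raise RuntimeError(f"Error initializing TGIAdapter: {e}") from e
except RuntimeError as err:
    assert isinstance(err.__cause__, ConnectionError)
```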
@@ -244,12 +241,15 @@ class InferenceEndpointAdapter(TGIAdapter):
                 "or 'endpoint_name'"
             )
         if "/" not in hf_endpoint_name:
-            hf_namespace: str = self.config.get_namespace()
+            hf_namespace: str = self.get_namespace()
             endpoint_path = f"{hf_namespace}/{hf_endpoint_name}"
         else:
             endpoint_path = hf_endpoint_name
         return f"https://api.endpoints.huggingface.cloud/v2/endpoint/{endpoint_path}"

+    def get_namespace(self) -> str:
+        return HfApi().whoami()["name"]
+
     @property
     def client(self) -> InferenceClient:
         return InferenceClient(model=self.inference_url, token=self.config.api_token)
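The helper that previously lived on the config now sits on the adapter, and the URL builder prefixes the caller's Hugging Face namespace whenever the configured endpoint name has no `owner/` part. A hedged sketch of that path logic in isolation: the surrounding method's name is not shown in this hunk, so `build_endpoint_url` below is hypothetical, and `namespace` stands in for `HfApi().whoami()["name"]`, which requires an authenticated token in the real code.

```python
# Hypothetical standalone version of the URL construction shown above.
def build_endpoint_url(hf_endpoint_name: str, namespace: str) -> str:
    if "/" not in hf_endpoint_name:
        endpoint_path = f"{namespace}/{hf_endpoint_name}"
    else:
        endpoint_path = hf_endpoint_name
    return f"https://api.endpoints.huggingface.cloud/v2/endpoint/{endpoint_path}"


assert (
    build_endpoint_url("my-tgi-endpoint", "my-org")
    == "https://api.endpoints.huggingface.cloud/v2/endpoint/my-org/my-tgi-endpoint"
)
```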