From 8266c75246f19c86fcda4378539bfb500d618151 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Mon, 23 Sep 2024 15:38:14 -0700
Subject: [PATCH] Quick example illustrating `get_request_provider_data`

---
 llama_stack/apis/inference/client.py                   | 5 ++++-
 llama_stack/providers/adapters/inference/tgi/config.py | 6 ++++++
 llama_stack/providers/adapters/inference/tgi/tgi.py    | 8 +++++++-
 llama_stack/providers/registry/inference.py            | 1 +
 4 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/llama_stack/apis/inference/client.py b/llama_stack/apis/inference/client.py
index 4df138841..b87a0dcaf 100644
--- a/llama_stack/apis/inference/client.py
+++ b/llama_stack/apis/inference/client.py
@@ -68,7 +68,10 @@ class InferenceClient(Inference):
                 "POST",
                 f"{self.base_url}/inference/chat_completion",
                 json=encodable_dict(request),
-                headers={"Content-Type": "application/json"},
+                headers={
+                    "Content-Type": "application/json",
+                    "X-LlamaStack-ProviderData": json.dumps({"tgi_api_key": "1234"}),
+                },
                 timeout=20,
             ) as response:
                 if response.status_code != 200:
diff --git a/llama_stack/providers/adapters/inference/tgi/config.py b/llama_stack/providers/adapters/inference/tgi/config.py
index a0135dfdd..7dbe40e02 100644
--- a/llama_stack/providers/adapters/inference/tgi/config.py
+++ b/llama_stack/providers/adapters/inference/tgi/config.py
@@ -10,6 +10,12 @@ from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
 
 
+class TGIRequestProviderData(BaseModel):
+    # if there _is_ provider data, it must specify the API KEY
+    # if you want it to be optional, use Optional[str]
+    tgi_api_key: str
+
+
 @json_schema_type
 class TGIImplConfig(BaseModel):
     url: Optional[str] = Field(
diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py
index 6a385896d..aa895d4d1 100644
--- a/llama_stack/providers/adapters/inference/tgi/tgi.py
+++ b/llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -14,9 +14,10 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.distribution.request_headers import get_request_provider_data
 from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
 
-from .config import TGIImplConfig
+from .config import TGIImplConfig, TGIRequestProviderData
 
 
 class TGIAdapter(Inference):
@@ -84,6 +85,11 @@ class TGIAdapter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        request_provider_data = get_request_provider_data()
+        if request_provider_data is not None:
+            assert isinstance(request_provider_data, TGIRequestProviderData)
+            print(f"TGI API KEY: {request_provider_data.tgi_api_key}")
+
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index e862c559f..3f432953a 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -50,6 +50,7 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=["huggingface_hub"],
             module="llama_stack.providers.adapters.inference.tgi",
             config_class="llama_stack.providers.adapters.inference.tgi.TGIImplConfig",
+            provider_data_validator="llama_stack.providers.adapters.inference.tgi.TGIRequestProviderData",
         ),
     ),
     remote_provider_spec(
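
For context, here is a minimal, self-contained sketch of the flow this patch wires up: the client puts JSON in the `X-LlamaStack-ProviderData` header, the server validates it against the `provider_data_validator` class registered for the provider, and the adapter retrieves the validated object via `get_request_provider_data()`. The `parse_provider_data_header` helper and the contextvar below are hypothetical stand-ins for whatever `llama_stack.distribution.request_headers` actually does; they are illustration only, not the real implementation.

```python
# Sketch of the provider-data flow implied by the patch (assumptions noted above).
import json
from contextvars import ContextVar
from typing import Any, Dict, Optional

from pydantic import BaseModel


class TGIRequestProviderData(BaseModel):
    # mirrors the validator added in tgi/config.py
    tgi_api_key: str


# hypothetical per-request storage; the real code may use request-scoped state instead
_provider_data: ContextVar[Optional[Any]] = ContextVar("provider_data", default=None)


def parse_provider_data_header(headers: Dict[str, str], validator_cls) -> None:
    """Validate the X-LlamaStack-ProviderData header (if present) and stash the result."""
    raw = headers.get("X-LlamaStack-ProviderData")
    if raw is None:
        _provider_data.set(None)
        return
    _provider_data.set(validator_cls(**json.loads(raw)))


def get_request_provider_data() -> Optional[Any]:
    """What the TGI adapter calls at the top of chat_completion()."""
    return _provider_data.get()


if __name__ == "__main__":
    # Same header the modified InferenceClient sends.
    headers = {"X-LlamaStack-ProviderData": json.dumps({"tgi_api_key": "1234"})}
    parse_provider_data_header(headers, TGIRequestProviderData)

    data = get_request_provider_data()
    assert isinstance(data, TGIRequestProviderData)
    print(f"TGI API KEY: {data.tgi_api_key}")
```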