Quick example illustrating get_request_provider_data

Ashwin Bharambe 2024-09-23 15:38:14 -07:00
parent 9eb5ec3e4b
commit 8266c75246
4 changed files with 18 additions and 2 deletions
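
Note: the server-side helper get_request_provider_data() used by the adapter below lives in llama_stack.distribution.request_headers and its implementation is not part of this diff. As a rough mental model only (a sketch assuming a contextvars-based implementation and a hypothetical set_request_provider_data() hook invoked by the server for each request), it parses the X-LlamaStack-ProviderData header, validates it with the provider's provider_data_validator class, and exposes the result to the adapter for the duration of that request:

# Sketch only -- the real request_headers module may look different.
import contextvars
import json

_provider_data = contextvars.ContextVar("provider_data", default=None)


def set_request_provider_data(headers: dict, validator_class) -> None:
    # Hypothetical hook the server would call before dispatching a request.
    raw = headers.get("X-LlamaStack-ProviderData")
    if raw is None:
        _provider_data.set(None)
        return
    # Validate the JSON payload with the provider's pydantic model
    # (e.g. TGIRequestProviderData) so adapters receive a typed object.
    _provider_data.set(validator_class(**json.loads(raw)))


def get_request_provider_data():
    # Adapters call this inside a request handler to read the validated data.
    return _provider_data.get()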


@@ -68,7 +68,10 @@ class InferenceClient(Inference):
             "POST",
             f"{self.base_url}/inference/chat_completion",
             json=encodable_dict(request),
-            headers={"Content-Type": "application/json"},
+            headers={
+                "Content-Type": "application/json",
+                "X-LlamaStack-ProviderData": json.dumps({"tgi_api_key": "1234"}),
+            },
             timeout=20,
         ) as response:
             if response.status_code != 200:
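
On the client side, provider data is just one extra header: a JSON object whose keys match the provider's validator model. The same request can be issued without the InferenceClient wrapper, for example with httpx (URL, port, and payload below are illustrative assumptions, not part of this commit):

# Standalone illustration of the request shape; server address is assumed.
import json

import httpx

provider_data = {"tgi_api_key": "1234"}  # keys must match TGIRequestProviderData

response = httpx.post(
    "http://localhost:5000/inference/chat_completion",
    json={
        "model": "Meta-Llama3.1-8B-Instruct",  # illustrative request body
        "messages": [{"role": "user", "content": "hello"}],
    },
    headers={
        "Content-Type": "application/json",
        "X-LlamaStack-ProviderData": json.dumps(provider_data),
    },
    timeout=20,
)
print(response.status_code)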


@@ -10,6 +10,12 @@ from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
 
 
+class TGIRequestProviderData(BaseModel):
+    # if there _is_ provider data, it must specify the API KEY
+    # if you want it to be optional, use Optional[str]
+    tgi_api_key: str
+
+
 @json_schema_type
 class TGIImplConfig(BaseModel):
     url: Optional[str] = Field(
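
TGIRequestProviderData is a plain pydantic model, so validating the header payload is just model construction: because tgi_api_key is declared as a required str, a payload that omits it is rejected. A small illustration (assuming the class is importable from the tgi adapter package, as the registry's provider_data_validator string below implies):

import json

from pydantic import ValidationError

from llama_stack.providers.adapters.inference.tgi import TGIRequestProviderData

# Well-formed provider data validates into a typed object.
ok = TGIRequestProviderData(**json.loads('{"tgi_api_key": "1234"}'))
print(ok.tgi_api_key)

# Provider data without the required key is rejected.
try:
    TGIRequestProviderData(**json.loads("{}"))
except ValidationError as e:
    print(e)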


@@ -14,9 +14,10 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.distribution.request_headers import get_request_provider_data
 from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
 
-from .config import TGIImplConfig
+from .config import TGIImplConfig, TGIRequestProviderData
 
 class TGIAdapter(Inference):
@@ -84,6 +85,11 @@ class TGIAdapter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        request_provider_data = get_request_provider_data()
+        if request_provider_data is not None:
+            assert isinstance(request_provider_data, TGIRequestProviderData)
+            print(f"TGI API KEY: {request_provider_data.tgi_api_key}")
+
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
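
The adapter only prints the key here to keep the example small. In a real adapter the per-request key would presumably take precedence over whatever is configured in TGIImplConfig and be passed to the client that talks to TGI. A hypothetical continuation (assuming the adapter uses huggingface_hub's AsyncInferenceClient and an api_token field on the config; neither is part of this commit):

# Hypothetical continuation of chat_completion(); not part of this commit.
from huggingface_hub import AsyncInferenceClient

api_token = getattr(self.config, "api_token", None)  # assumed config field
if request_provider_data is not None:
    api_token = request_provider_data.tgi_api_key  # per-request key wins

client = AsyncInferenceClient(model=self.config.url, token=api_token)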


@@ -50,6 +50,7 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=["huggingface_hub"],
             module="llama_stack.providers.adapters.inference.tgi",
             config_class="llama_stack.providers.adapters.inference.tgi.TGIImplConfig",
+            provider_data_validator="llama_stack.providers.adapters.inference.tgi.TGIRequestProviderData",
         ),
     ),
     remote_provider_spec(
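
provider_data_validator is stored as a dotted string, so the distribution layer can resolve it lazily instead of importing every provider up front. A sketch of that resolution step (assuming a small helper like the one below; the actual distribution-side code is not part of this diff):

# Sketch: turning the dotted provider_data_validator string into a class.
import importlib


def resolve_validator(dotted_path: str) -> type:
    module_path, _, class_name = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_path), class_name)


validator_cls = resolve_validator(
    "llama_stack.providers.adapters.inference.tgi.TGIRequestProviderData"
)
# validator_cls can then be handed to something like set_request_provider_data()
# (see the sketch near the top) to parse the X-LlamaStack-ProviderData header.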