Quick example illustrating get_request_provider_data

Ashwin Bharambe 2024-09-23 15:38:14 -07:00
parent 9eb5ec3e4b
commit 8266c75246
4 changed files with 18 additions and 2 deletions


@@ -14,9 +14,10 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.distribution.request_headers import get_request_provider_data
 from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
-from .config import TGIImplConfig
+from .config import TGIImplConfig, TGIRequestProviderData


 class TGIAdapter(Inference):
@@ -84,6 +85,11 @@ class TGIAdapter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        request_provider_data = get_request_provider_data()
+        if request_provider_data is not None:
+            assert isinstance(request_provider_data, TGIRequestProviderData)
+            print(f"TGI API KEY: {request_provider_data.tgi_api_key}")
+
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
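
Below is a self-contained sketch of the pattern this commit illustrates. It is not the llama_stack implementation; it re-creates the idea with a contextvar so the snippet runs on its own. The X-LlamaStack-ProviderData header name, the set_request_provider_data helper, and the exact shape of TGIRequestProviderData are assumptions for illustration; only get_request_provider_data() and the tgi_api_key attribute appear in the diff above.

import json
from contextvars import ContextVar
from typing import Dict, Optional

from pydantic import BaseModel


class TGIRequestProviderData(BaseModel):
    # Field name taken from the print() call in the diff above.
    tgi_api_key: str


# Request-scoped storage; the real library's mechanism may differ.
_provider_data: ContextVar[Optional[TGIRequestProviderData]] = ContextVar(
    "provider_data", default=None
)


def set_request_provider_data(headers: Dict[str, str]) -> None:
    # Hypothetical server-side hook: parse provider data out of a header
    # (header name is an assumption) and stash it for this request.
    raw = headers.get("X-LlamaStack-ProviderData")
    if raw is not None:
        _provider_data.set(TGIRequestProviderData(**json.loads(raw)))


def get_request_provider_data() -> Optional[TGIRequestProviderData]:
    # Counterpart of the function the diff imports: return whatever
    # provider data was attached to the current request, if any.
    return _provider_data.get()


class TGIAdapter:
    def chat_completion(self, model: str) -> str:
        # Same pattern as the diff: fetch the request-scoped data, check it
        # is the type this provider expects, then use its credentials.
        request_provider_data = get_request_provider_data()
        if request_provider_data is not None:
            assert isinstance(request_provider_data, TGIRequestProviderData)
            return f"calling {model} with key {request_provider_data.tgi_api_key!r}"
        return f"calling {model} with no per-request key"


if __name__ == "__main__":
    # Simulate one request that carries a TGI API key in its headers.
    set_request_provider_data(
        {"X-LlamaStack-ProviderData": json.dumps({"tgi_api_key": "secret-123"})}
    )
    print(TGIAdapter().chat_completion("llama-3-8b"))

In this shape, the server parses per-request credentials out of a header once, and any provider adapter can pick them up later in the same request without threading them through every call signature.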