Quick example illustrating get_request_provider_data
commit 8266c75246 (parent 9eb5ec3e4b)
4 changed files with 18 additions and 2 deletions
```diff
@@ -68,7 +68,10 @@ class InferenceClient(Inference):
             "POST",
             f"{self.base_url}/inference/chat_completion",
             json=encodable_dict(request),
-            headers={"Content-Type": "application/json"},
+            headers={
+                "Content-Type": "application/json",
+                "X-LlamaStack-ProviderData": json.dumps({"tgi_api_key": "1234"}),
+            },
             timeout=20,
         ) as response:
             if response.status_code != 200:
```
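Pulled out of its diff context, the client-side contract is small: provider-specific credentials travel as a JSON object in the `X-LlamaStack-ProviderData` request header. A minimal standalone sketch of the same call (the endpoint path and `httpx` streaming usage mirror the diff above; `call_chat_completion` and the hard-coded key are illustrative placeholders):

```python
import json

import httpx


async def call_chat_completion(base_url: str, request_body: dict) -> None:
    # Provider-specific secrets ride along in a JSON-encoded header; the
    # server hands them to the matching provider via get_request_provider_data.
    headers = {
        "Content-Type": "application/json",
        "X-LlamaStack-ProviderData": json.dumps({"tgi_api_key": "1234"}),
    }
    async with httpx.AsyncClient() as client:
        async with client.stream(
            "POST",
            f"{base_url}/inference/chat_completion",
            json=request_body,
            headers=headers,
            timeout=20,
        ) as response:
            response.raise_for_status()
```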
```diff
@@ -10,6 +10,12 @@ from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
 
 
+class TGIRequestProviderData(BaseModel):
+    # if there _is_ provider data, it must specify the API KEY
+    # if you want it to be optional, use Optional[str]
+    tgi_api_key: str
+
+
 @json_schema_type
 class TGIImplConfig(BaseModel):
     url: Optional[str] = Field(
```
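Because `TGIRequestProviderData` is a plain pydantic model with a required field, validation of the header payload comes for free: a present-but-malformed payload is rejected rather than silently ignored. A quick standalone sketch of that behavior, using the class exactly as defined above:

```python
import json

from pydantic import BaseModel, ValidationError


class TGIRequestProviderData(BaseModel):
    tgi_api_key: str


# A well-formed header payload parses into a typed object...
payload = json.loads('{"tgi_api_key": "1234"}')
print(TGIRequestProviderData(**payload).tgi_api_key)  # -> 1234

# ...but once provider data is present at all, the API key is mandatory.
try:
    TGIRequestProviderData(**{})
except ValidationError as exc:
    print(exc.errors()[0]["loc"])  # -> ('tgi_api_key',)
```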
```diff
@@ -14,9 +14,10 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.distribution.request_headers import get_request_provider_data
 from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
 
-from .config import TGIImplConfig
+from .config import TGIImplConfig, TGIRequestProviderData
 
 
 class TGIAdapter(Inference):
```
```diff
@@ -84,6 +85,11 @@ class TGIAdapter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        request_provider_data = get_request_provider_data()
+        if request_provider_data is not None:
+            assert isinstance(request_provider_data, TGIRequestProviderData)
+            print(f"TGI API KEY: {request_provider_data.tgi_api_key}")
+
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
```
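The implementation of `get_request_provider_data` lives in `llama_stack.distribution.request_headers` and is not part of this diff. As a rough mental model only (the function and parameter names below are assumptions, not the actual API), the server side amounts to parsing the header and validating it against the provider's registered validator class:

```python
import json
from typing import Dict, Optional, Type

from pydantic import BaseModel


def parse_provider_data(
    headers: Dict[str, str], validator_class: Type[BaseModel]
) -> Optional[BaseModel]:
    # Hypothetical sketch: no header means no provider data (return None);
    # a header that is present must validate against the provider's model.
    raw = headers.get("X-LlamaStack-ProviderData")
    if raw is None:
        return None
    return validator_class(**json.loads(raw))
```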
```diff
@@ -50,6 +50,7 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=["huggingface_hub"],
             module="llama_stack.providers.adapters.inference.tgi",
             config_class="llama_stack.providers.adapters.inference.tgi.TGIImplConfig",
+            provider_data_validator="llama_stack.providers.adapters.inference.tgi.TGIRequestProviderData",
         ),
     ),
     remote_provider_spec(
```
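`provider_data_validator` is given as a dotted string rather than a class object, presumably so the provider registry stays importable without pulling in every provider's dependencies. A hedged sketch of how such a string could be resolved at request time (`load_validator` is an illustrative helper, not llama-stack's actual loader):

```python
import importlib


def load_validator(dotted_path: str) -> type:
    # Split "package.module.ClassName" into module path and class name,
    # import the module lazily, and fetch the class off it.
    module_path, _, class_name = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_path), class_name)


validator_cls = load_validator(
    "llama_stack.providers.adapters.inference.tgi.TGIRequestProviderData"
)
```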