Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 12:07:34 +00:00)
Quick example illustrating get_request_provider_data
Parent: 9eb5ec3e4b
Commit: 8266c75246
4 changed files with 18 additions and 2 deletions
@@ -68,7 +68,10 @@ class InferenceClient(Inference):
                 "POST",
                 f"{self.base_url}/inference/chat_completion",
                 json=encodable_dict(request),
-                headers={"Content-Type": "application/json"},
+                headers={
+                    "Content-Type": "application/json",
+                    "X-LlamaStack-ProviderData": json.dumps({"tgi_api_key": "1234"}),
+                },
                 timeout=20,
             ) as response:
                 if response.status_code != 200:
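For context, a minimal caller-side sketch of the same idea: per-request provider data is JSON-encoded into the X-LlamaStack-ProviderData header next to the normal request body. Only the header name and the tgi_api_key field come from this commit; the payload shape, port, and the standalone httpx call are illustrative assumptions, not the stock client.

# Hedged sketch: how any HTTP caller could attach provider data per request.
import asyncio
import json

import httpx


async def main(base_url: str) -> None:
    provider_data = {"tgi_api_key": "1234"}  # secrets consumed by the TGI adapter
    body = {
        # Illustrative payload; the real client serializes a ChatCompletionRequest.
        "model": "Meta-Llama3.1-8B-Instruct",
        "messages": [{"role": "user", "content": "Hello"}],
    }
    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{base_url}/inference/chat_completion",
            json=body,
            headers={
                "Content-Type": "application/json",
                "X-LlamaStack-ProviderData": json.dumps(provider_data),
            },
            timeout=20,
        )
        print(response.status_code)


if __name__ == "__main__":
    asyncio.run(main("http://localhost:5000"))  # illustrative address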
@@ -10,6 +10,12 @@ from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
 
 
+class TGIRequestProviderData(BaseModel):
+    # if there _is_ provider data, it must specify the API KEY
+    # if you want it to be optional, use Optional[str]
+    tgi_api_key: str
+
+
 @json_schema_type
 class TGIImplConfig(BaseModel):
     url: Optional[str] = Field(
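Since TGIRequestProviderData is a plain pydantic model, the header value round-trips through ordinary validation. A small sketch of that; the import path follows the registry entry later in this diff, and the values are illustrative.

# Hedged sketch: validating the X-LlamaStack-ProviderData payload by hand.
import json

from llama_stack.providers.adapters.inference.tgi import TGIRequestProviderData

raw = '{"tgi_api_key": "1234"}'  # what the client above puts in the header
data = TGIRequestProviderData(**json.loads(raw))
assert data.tgi_api_key == "1234"

# tgi_api_key is a required str, so an empty payload raises a pydantic
# ValidationError; annotate it as Optional[str] to make it optional instead.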
@@ -14,9 +14,10 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.distribution.request_headers import get_request_provider_data
 from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
 
-from .config import TGIImplConfig
+from .config import TGIImplConfig, TGIRequestProviderData
 
 
 class TGIAdapter(Inference):
@@ -84,6 +85,11 @@ class TGIAdapter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        request_provider_data = get_request_provider_data()
+        if request_provider_data is not None:
+            assert isinstance(request_provider_data, TGIRequestProviderData)
+            print(f"TGI API KEY: {request_provider_data.tgi_api_key}")
+
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
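The print() above is only a demonstration; a real adapter would presumably thread the key into its backend client. A hedged sketch, assuming huggingface_hub's InferenceClient (which the registry entry below pulls in) and its token argument; this usage is an assumption, not part of this commit.

# Hedged sketch: consuming the per-request key instead of just printing it.
from typing import Optional

from huggingface_hub import InferenceClient

from llama_stack.distribution.request_headers import get_request_provider_data

from .config import TGIRequestProviderData


def tgi_client(url: str) -> InferenceClient:
    token: Optional[str] = None
    provider_data = get_request_provider_data()
    if isinstance(provider_data, TGIRequestProviderData):
        token = provider_data.tgi_api_key
    # `model` may be the URL of a deployed TGI endpoint; `token` is sent as a
    # Bearer header on every inference call.
    return InferenceClient(model=url, token=token)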
@@ -50,6 +50,7 @@ def available_providers() -> List[ProviderSpec]:
                 pip_packages=["huggingface_hub"],
                 module="llama_stack.providers.adapters.inference.tgi",
                 config_class="llama_stack.providers.adapters.inference.tgi.TGIImplConfig",
+                provider_data_validator="llama_stack.providers.adapters.inference.tgi.TGIRequestProviderData",
             ),
         ),
         remote_provider_spec(
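Putting the pieces together: the client JSON-encodes provider data into X-LlamaStack-ProviderData, the distribution layer validates it against the class named by provider_data_validator, and the adapter reads the result via get_request_provider_data(). A hedged sketch of that middle step; the real logic lives in llama_stack.distribution.request_headers and may differ in detail.

# Hedged sketch of the header-parsing step, not the actual llama_stack code.
import json
from typing import Dict, Optional, Type

from pydantic import BaseModel, ValidationError


def parse_provider_data(
    headers: Dict[str, str], validator: Type[BaseModel]
) -> Optional[BaseModel]:
    raw = headers.get("X-LlamaStack-ProviderData")
    if raw is None:
        return None
    try:
        # Validate the JSON payload against the provider's declared model,
        # e.g. TGIRequestProviderData for the TGI adapter.
        return validator(**json.loads(raw))
    except (json.JSONDecodeError, ValidationError):
        return None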