forked from phoenix-oss/llama-stack-mirror
Allow simpler initialization of RemoteProviderConfig
; fix issue in httpx client
This commit is contained in:
parent
064d2a5287
commit
489f74a70b
3 changed files with 28 additions and 9 deletions
|
@ -143,14 +143,21 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
|
|||
else:
|
||||
data.update(convert(kwargs))
|
||||
|
||||
return dict(
|
||||
ret = dict(
|
||||
method=webmethod.method or "POST",
|
||||
url=url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
params=params,
|
||||
json=data,
|
||||
headers={
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
if params:
|
||||
ret["params"] = params
|
||||
if data:
|
||||
ret["json"] = data
|
||||
|
||||
return ret
|
||||
|
||||
# Add protocol methods to the wrapper
|
||||
for p in protocols:
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, Protocol
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
@ -145,13 +146,20 @@ Fully-qualified name of the module to import. The module is expected to have:
|
|||
|
||||
class RemoteProviderConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 0
|
||||
port: Optional[int] = None
|
||||
protocol: str = "http"
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
if self.port is None:
|
||||
return f"{self.protocol}://{self.host}"
|
||||
return f"{self.protocol}://{self.host}:{self.port}"
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str) -> "RemoteProviderConfig":
|
||||
parsed = urlparse(url)
|
||||
return cls(host=parsed.hostname, port=parsed.port, protocol=parsed.scheme)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RemoteProviderSpec(ProviderSpec):
|
||||
|
|
|
@ -25,15 +25,19 @@ class ProviderFixture(BaseModel):
|
|||
|
||||
|
||||
def remote_stack_fixture() -> ProviderFixture:
|
||||
if url := os.getenv("REMOTE_STACK_URL", None):
|
||||
config = RemoteProviderConfig.from_url(url)
|
||||
else:
|
||||
config = RemoteProviderConfig(
|
||||
host=get_env_or_fail("REMOTE_STACK_HOST"),
|
||||
port=int(get_env_or_fail("REMOTE_STACK_PORT")),
|
||||
)
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="remote",
|
||||
provider_type="remote",
|
||||
config=RemoteProviderConfig(
|
||||
host=get_env_or_fail("REMOTE_STACK_HOST"),
|
||||
port=int(get_env_or_fail("REMOTE_STACK_PORT")),
|
||||
).model_dump(),
|
||||
config=config.model_dump(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue