Allow simpler initialization of RemoteProviderConfig; fix issue in httpx client

This commit is contained in:
Ashwin Bharambe 2024-11-06 19:18:58 -08:00
parent 064d2a5287
commit 489f74a70b
3 changed files with 28 additions and 9 deletions

View file

@ -143,14 +143,21 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
else: else:
data.update(convert(kwargs)) data.update(convert(kwargs))
return dict( ret = dict(
method=webmethod.method or "POST", method=webmethod.method or "POST",
url=url, url=url,
headers={"Content-Type": "application/json"}, headers={
params=params, "Accept": "application/json",
json=data, "Content-Type": "application/json",
},
timeout=30, timeout=30,
) )
if params:
ret["params"] = params
if data:
ret["json"] = data
return ret
# Add protocol methods to the wrapper # Add protocol methods to the wrapper
for p in protocols: for p in protocols:

View file

@ -6,6 +6,7 @@
from enum import Enum from enum import Enum
from typing import Any, List, Optional, Protocol from typing import Any, List, Optional, Protocol
from urllib.parse import urlparse
from llama_models.schema_utils import json_schema_type from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -145,13 +146,20 @@ Fully-qualified name of the module to import. The module is expected to have:
class RemoteProviderConfig(BaseModel): class RemoteProviderConfig(BaseModel):
host: str = "localhost" host: str = "localhost"
port: int = 0 port: Optional[int] = None
protocol: str = "http" protocol: str = "http"
@property @property
def url(self) -> str: def url(self) -> str:
if self.port is None:
return f"{self.protocol}://{self.host}"
return f"{self.protocol}://{self.host}:{self.port}" return f"{self.protocol}://{self.host}:{self.port}"
@classmethod
def from_url(cls, url: str) -> "RemoteProviderConfig":
parsed = urlparse(url)
return cls(host=parsed.hostname, port=parsed.port, protocol=parsed.scheme)
@json_schema_type @json_schema_type
class RemoteProviderSpec(ProviderSpec): class RemoteProviderSpec(ProviderSpec):

View file

@ -25,15 +25,19 @@ class ProviderFixture(BaseModel):
def remote_stack_fixture() -> ProviderFixture: def remote_stack_fixture() -> ProviderFixture:
if url := os.getenv("REMOTE_STACK_URL", None):
config = RemoteProviderConfig.from_url(url)
else:
config = RemoteProviderConfig(
host=get_env_or_fail("REMOTE_STACK_HOST"),
port=int(get_env_or_fail("REMOTE_STACK_PORT")),
)
return ProviderFixture( return ProviderFixture(
providers=[ providers=[
Provider( Provider(
provider_id="remote", provider_id="remote",
provider_type="remote", provider_type="remote",
config=RemoteProviderConfig( config=config.model_dump(),
host=get_env_or_fail("REMOTE_STACK_HOST"),
port=int(get_env_or_fail("REMOTE_STACK_PORT")),
).model_dump(),
) )
], ],
) )