Allow simpler initialization of RemoteProviderConfig; fix issue in httpx client

This commit is contained in:
Ashwin Bharambe 2024-11-06 19:18:58 -08:00
parent 064d2a5287
commit 489f74a70b
3 changed files with 28 additions and 9 deletions

View file

@ -143,14 +143,21 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
else:
data.update(convert(kwargs))
return dict(
ret = dict(
method=webmethod.method or "POST",
url=url,
headers={"Content-Type": "application/json"},
params=params,
json=data,
headers={
"Accept": "application/json",
"Content-Type": "application/json",
},
timeout=30,
)
if params:
ret["params"] = params
if data:
ret["json"] = data
return ret
# Add protocol methods to the wrapper
for p in protocols:

View file

@ -6,6 +6,7 @@
from enum import Enum
from typing import Any, List, Optional, Protocol
from urllib.parse import urlparse
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@ -145,13 +146,20 @@ Fully-qualified name of the module to import. The module is expected to have:
class RemoteProviderConfig(BaseModel):
host: str = "localhost"
port: int = 0
port: Optional[int] = None
protocol: str = "http"
@property
def url(self) -> str:
if self.port is None:
return f"{self.protocol}://{self.host}"
return f"{self.protocol}://{self.host}:{self.port}"
@classmethod
def from_url(cls, url: str) -> "RemoteProviderConfig":
parsed = urlparse(url)
return cls(host=parsed.hostname, port=parsed.port, protocol=parsed.scheme)
@json_schema_type
class RemoteProviderSpec(ProviderSpec):

View file

@ -25,15 +25,19 @@ class ProviderFixture(BaseModel):
def remote_stack_fixture() -> ProviderFixture:
if url := os.getenv("REMOTE_STACK_URL", None):
config = RemoteProviderConfig.from_url(url)
else:
config = RemoteProviderConfig(
host=get_env_or_fail("REMOTE_STACK_HOST"),
port=int(get_env_or_fail("REMOTE_STACK_PORT")),
)
return ProviderFixture(
providers=[
Provider(
provider_id="remote",
provider_type="remote",
config=RemoteProviderConfig(
host=get_env_or_fail("REMOTE_STACK_HOST"),
port=int(get_env_or_fail("REMOTE_STACK_PORT")),
).model_dump(),
config=config.model_dump(),
)
],
)