From 489f74a70b3d39fc7ce574d46fee43d90b3bda1f Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 6 Nov 2024 19:18:58 -0800 Subject: [PATCH] Allow simpler initialization of `RemoteProviderConfig`; fix issue in httpx client --- llama_stack/distribution/client.py | 15 +++++++++++---- llama_stack/providers/datatypes.py | 10 +++++++++- llama_stack/providers/tests/conftest.py | 12 ++++++++---- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/llama_stack/distribution/client.py b/llama_stack/distribution/client.py index 613c90bd6..ce788a713 100644 --- a/llama_stack/distribution/client.py +++ b/llama_stack/distribution/client.py @@ -143,14 +143,21 @@ def create_api_client_class(protocol, additional_protocol) -> Type: else: data.update(convert(kwargs)) - return dict( + ret = dict( method=webmethod.method or "POST", url=url, - headers={"Content-Type": "application/json"}, - params=params, - json=data, + headers={ + "Accept": "application/json", + "Content-Type": "application/json", + }, timeout=30, ) + if params: + ret["params"] = params + if data: + ret["json"] = data + + return ret # Add protocol methods to the wrapper for p in protocols: diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 69255fc5f..919507d11 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -6,6 +6,7 @@ from enum import Enum from typing import Any, List, Optional, Protocol +from urllib.parse import urlparse from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field @@ -145,13 +146,20 @@ Fully-qualified name of the module to import. The module is expected to have: class RemoteProviderConfig(BaseModel): host: str = "localhost" - port: int = 0 + port: Optional[int] = None protocol: str = "http" @property def url(self) -> str: + if self.port is None: + return f"{self.protocol}://{self.host}" return f"{self.protocol}://{self.host}:{self.port}" + @classmethod + def from_url(cls, url: str) -> "RemoteProviderConfig": + parsed = urlparse(url) + return cls(host=parsed.hostname, port=parsed.port, protocol=parsed.scheme) + @json_schema_type class RemoteProviderSpec(ProviderSpec): diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py index 11b0dcb45..2278e1a6c 100644 --- a/llama_stack/providers/tests/conftest.py +++ b/llama_stack/providers/tests/conftest.py @@ -25,15 +25,19 @@ class ProviderFixture(BaseModel): def remote_stack_fixture() -> ProviderFixture: + if url := os.getenv("REMOTE_STACK_URL", None): + config = RemoteProviderConfig.from_url(url) + else: + config = RemoteProviderConfig( + host=get_env_or_fail("REMOTE_STACK_HOST"), + port=int(get_env_or_fail("REMOTE_STACK_PORT")), + ) return ProviderFixture( providers=[ Provider( provider_id="remote", provider_type="remote", - config=RemoteProviderConfig( - host=get_env_or_fail("REMOTE_STACK_HOST"), - port=int(get_env_or_fail("REMOTE_STACK_PORT")), - ).model_dump(), + config=config.model_dump(), ) ], )