ollama remote adapter works

This commit is contained in:
Ashwin Bharambe 2024-08-28 06:51:07 -07:00
parent 2076d2b6db
commit 2a1552a5eb
14 changed files with 196 additions and 128 deletions

View file

@ -38,11 +38,9 @@ from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints
from .dynamic import instantiate_client, instantiate_provider
from .registry import resolve_distribution_spec
from .datatypes import Api, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints, api_providers
from .dynamic import instantiate_provider
def is_async_iterator_type(typ):
@ -249,9 +247,11 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
return [by_id[x] for x in stack]
def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, Any]:
def resolve_impls(
provider_specs: Dict[str, ProviderSpec], config: Dict[str, Any]
) -> Dict[Api, Any]:
provider_configs = config["providers"]
provider_specs = topological_sort(dist.provider_specs.values())
provider_specs = topological_sort(provider_specs.values())
impls = {}
for provider_spec in provider_specs:
@ -261,16 +261,10 @@ def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, A
f"Could not find provider_spec config for {api}. Please add it to the config"
)
deps = {api: impls[api] for api in provider_spec.api_dependencies}
provider_config = provider_configs[api.value]
if isinstance(provider_spec, RemoteProviderSpec):
impls[api] = instantiate_client(
provider_spec,
provider_config,
)
else:
deps = {api: impls[api] for api in provider_spec.api_dependencies}
impl = instantiate_provider(provider_spec, provider_config, deps)
impls[api] = impl
impl = instantiate_provider(provider_spec, provider_config, deps)
impls[api] = impl
return impls
@ -279,22 +273,34 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
with open(yaml_config, "r") as fp:
config = yaml.safe_load(fp)
spec = config["spec"]
dist = resolve_distribution_spec(spec)
if dist is None:
raise ValueError(f"Could not find distribution specification `{spec}`")
app = FastAPI()
all_endpoints = api_endpoints()
impls = resolve_impls(dist, config)
all_providers = api_providers()
for provider_spec in dist.provider_specs.values():
provider_specs = {}
for api_str, provider_config in config["providers"].items():
api = Api(api_str)
providers = all_providers[api]
provider_id = provider_config["provider_id"]
if provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
provider_specs[api] = providers[provider_id]
impls = resolve_impls(provider_specs, config)
for provider_spec in provider_specs.values():
api = provider_spec.api
endpoints = all_endpoints[api]
impl = impls[api]
if isinstance(provider_spec, RemoteProviderSpec):
if (
isinstance(provider_spec, RemoteProviderSpec)
and provider_spec.adapter is None
):
for endpoint in endpoints:
url = impl.base_url + endpoint.route
getattr(app, endpoint.method)(endpoint.route)(