mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-25 09:05:37 +00:00 
			
		
		
		
	We would like to rename the term `template` to `distribution`. To prepare for that, this is a precursor. cc @leseb
		
			
				
	
	
		
			189 lines
		
	
	
	
		
			6.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			189 lines
		
	
	
	
		
			6.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import inspect
 | |
| import json
 | |
| import sys
 | |
| from collections.abc import AsyncIterator
 | |
| from enum import Enum
 | |
| from typing import Any, Union, get_args, get_origin
 | |
| 
 | |
| import httpx
 | |
| from pydantic import BaseModel, parse_obj_as
 | |
| from termcolor import cprint
 | |
| 
 | |
| from llama_stack.apis.version import LLAMA_STACK_API_VERSION
 | |
| from llama_stack.providers.datatypes import RemoteProviderConfig
 | |
| 
 | |
| _CLIENT_CLASSES = {}
 | |
| 
 | |
| 
 | |
async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
    """Create and initialize a remote client implementing ``protocol``.

    The client class comes from ``create_api_client_class(protocol)`` and is
    instantiated with ``config.url``. ``_deps`` is accepted for signature
    compatibility but is not used here.

    Returns:
        The client instance, after its ``initialize()`` coroutine has run.
    """
    client = create_api_client_class(protocol)(config.url)
    await client.initialize()
    return client
 | |
| 
 | |
| 
 | |
def create_api_client_class(protocol) -> type:
    """Return an HTTP client class that mirrors the web methods of ``protocol``.

    Every member of ``protocol`` carrying a ``__webmethod__`` attribute is
    exposed on the generated class as an async method with the same name and
    signature; calling it issues the corresponding HTTP request. Generated
    classes are memoized per protocol in ``_CLIENT_CLASSES``.

    Args:
        protocol: The API class whose ``__webmethod__``-tagged members define
            the routes.

    Returns:
        The generated class, named ``"<protocol.__name__>Client"``.
    """
    if protocol in _CLIENT_CLASSES:
        return _CLIENT_CLASSES[protocol]

    class APIClient:
        """HTTP client whose methods are forwarded to a remote server."""

        def __init__(self, base_url: str):
            """Record the server URL and index the protocol's web methods."""
            print(f"({protocol.__name__}) Connecting to {base_url}")
            self.base_url = base_url.rstrip("/")
            # method name -> (webmethod metadata, inspect.Signature)
            self.routes = {}

            # Store routes for this protocol
            for name, method in inspect.getmembers(protocol):
                if hasattr(method, "__webmethod__"):
                    sig = inspect.signature(method)
                    self.routes[name] = (method.__webmethod__, sig)

        async def initialize(self):
            """No-op lifecycle hook."""
            pass

        async def shutdown(self):
            """No-op lifecycle hook."""
            pass

        async def __acall__(self, method_name: str, *args, **kwargs) -> Any:
            """Dispatch a call to the streaming or non-streaming code path.

            When ``stream`` is truthy in ``kwargs`` the (un-awaited) async
            generator from ``_call_streaming`` is returned; otherwise the
            parsed response of ``_call_non_streaming`` is returned.
            """
            assert method_name in self.routes, f"Unknown endpoint: {method_name}"

            # TODO: make this more precise, same thing needs to happen in server.py
            # NOTE: only a keyword ``stream=...`` is detected; a positional
            # ``stream`` argument is treated as non-streaming.
            is_streaming = kwargs.get("stream", False)
            if is_streaming:
                return self._call_streaming(method_name, *args, **kwargs)
            else:
                return await self._call_non_streaming(method_name, *args, **kwargs)

        async def _call_non_streaming(self, method_name: str, *args, **kwargs) -> Any:
            """Issue one request and parse the JSON body into the return type.

            Returns ``None`` when the response JSON is ``null``.

            Raises:
                httpx.HTTPStatusError: via ``raise_for_status()`` on 4xx/5xx.
            """
            _, sig = self.routes[method_name]

            if sig.return_annotation is None:
                return_type = None
            else:
                return_type = extract_non_async_iterator_type(sig.return_annotation)
                assert return_type, f"Could not extract return type for {sig.return_annotation}"

            async with httpx.AsyncClient() as client:
                params = self.httpx_request_params(method_name, *args, **kwargs)
                response = await client.request(**params)
                response.raise_for_status()

                j = response.json()
                if j is None:
                    return None
                # print(f"({protocol.__name__}) Returning {j}, type {return_type}")
                return parse_obj_as(return_type, j)

        async def _call_streaming(self, method_name: str, *args, **kwargs) -> Any:
            """Stream a response, yielding one parsed object per ``data:`` line.

            This is an async generator. Payloads containing an ``"error"`` key
            and payloads that fail to decode or validate are printed to stderr
            and skipped rather than raised, so the stream keeps going.

            Raises:
                httpx.HTTPStatusError: via ``raise_for_status()`` on 4xx/5xx.
            """
            _, sig = self.routes[method_name]

            return_type = extract_async_iterator_type(sig.return_annotation)
            assert return_type, f"Could not extract return type for {sig.return_annotation}"

            async with httpx.AsyncClient() as client:
                params = self.httpx_request_params(method_name, *args, **kwargs)
                async with client.stream(**params) as response:
                    response.raise_for_status()

                    async for line in response.aiter_lines():
                        if line.startswith("data:"):
                            # Strip exactly the prefix that was matched; the
                            # space after the colon is optional, so slicing
                            # off len("data: ") would eat the first payload
                            # character of a "data:{...}" line.
                            data = line[len("data:") :].lstrip()
                            try:
                                data = json.loads(data)
                                if "error" in data:
                                    cprint(data, color="red", file=sys.stderr)
                                    continue

                                yield parse_obj_as(return_type, data)
                            except Exception as e:
                                # Best effort: report the bad chunk and move on.
                                cprint(f"Error with parsing or validation: {e}", color="red", file=sys.stderr)
                                cprint(data, color="red", file=sys.stderr)

        def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
            """Build the keyword arguments for an ``httpx`` request.

            Positional ``args`` are bound, in order, to the method's declared
            parameters (excluding ``self``) and merged into ``kwargs``. For
            ``GET`` web methods the arguments are sent as query parameters;
            for every other method they are converted to JSON-compatible
            values and sent as the JSON body.

            Returns:
                A dict with ``method``, ``url``, ``headers`` and ``timeout``,
                plus ``params`` or ``json`` when there are arguments to send.
            """
            webmethod, sig = self.routes[method_name]

            parameters = list(sig.parameters.values())[1:]  # skip `self`
            # zip stops at the shorter sequence, so surplus positional
            # arguments are ignored just like unfilled parameters.
            for param, arg in zip(parameters, args):
                kwargs[param.name] = arg

            url = f"{self.base_url}/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"

            def convert(value):
                """Recursively turn models and enums into JSON-compatible values."""
                if isinstance(value, list):
                    return [convert(v) for v in value]
                elif isinstance(value, dict):
                    return {k: convert(v) for k, v in value.items()}
                elif isinstance(value, BaseModel):
                    return json.loads(value.model_dump_json())
                elif isinstance(value, Enum):
                    return value.value
                else:
                    return value

            params = {}
            data = {}
            if webmethod.method == "GET":
                # NOTE(review): query parameters are passed through without
                # convert(); confirm GET endpoints never take enums/models.
                params.update(kwargs)
            else:
                data.update(convert(kwargs))

            ret = dict(
                method=webmethod.method or "POST",
                url=url,
                headers={
                    "Accept": "application/json",
                    "Content-Type": "application/json",
                },
                timeout=30,
            )
            if params:
                ret["params"] = params
            if data:
                ret["json"] = data

            return ret

    # Add protocol methods to the wrapper
    for name, method in inspect.getmembers(protocol):
        if hasattr(method, "__webmethod__"):

            # ``method_name=name`` binds the current loop value as a default,
            # avoiding the late-binding closure pitfall.
            async def method_impl(self, *args, method_name=name, **kwargs):
                return await self.__acall__(method_name, *args, **kwargs)

            method_impl.__name__ = name
            method_impl.__qualname__ = f"APIClient.{name}"
            method_impl.__signature__ = inspect.signature(method)
            setattr(APIClient, name, method_impl)

    # Name the class after the protocol
    APIClient.__name__ = f"{protocol.__name__}Client"
    _CLIENT_CLASSES[protocol] = APIClient
    return APIClient
 | |
| 
 | |
| 
 | |
| # not quite general these methods are
 | |
def extract_non_async_iterator_type(type_hint):
    """Return the first non-``AsyncIterator`` member of a union type hint.

    Used to find the non-streaming return type of a method annotated as
    ``Union[Response, AsyncIterator[Chunk]]``.

    Args:
        type_hint: Any type annotation.

    Returns:
        The first union member that is not an ``AsyncIterator`` (or subclass).
        If ``type_hint`` is not a union, or every member is an async
        iterator, ``type_hint`` itself is returned unchanged.
    """
    import types  # local import: only needed to recognize PEP 604 unions

    # ``X | Y`` reports ``types.UnionType`` as its origin on Python 3.10+,
    # while ``Union[X, Y]`` reports ``typing.Union``.
    union_origins = (Union, getattr(types, "UnionType", Union))
    if get_origin(type_hint) in union_origins:
        for arg in get_args(type_hint):
            origin = get_origin(arg) or arg
            # Origins such as Literal/Annotated are not classes; calling
            # issubclass() on them raises TypeError, and they can never be
            # async iterators anyway.
            if not (isinstance(origin, type) and issubclass(origin, AsyncIterator)):
                return arg
    return type_hint
 | |
| 
 | |
| 
 | |
def extract_async_iterator_type(type_hint):
    """Return the element type of the ``AsyncIterator`` found in a type hint.

    Used to find the streaming chunk type of a method annotated as
    ``Union[Response, AsyncIterator[Chunk]]``. A hint that is itself an
    ``AsyncIterator[Chunk]`` (no union) is handled as well.

    Args:
        type_hint: Any type annotation.

    Returns:
        The first type argument of the first parameterized ``AsyncIterator``
        (or subclass) found, or ``None`` if there is none.
    """
    import types  # local import: only needed to recognize PEP 604 unions

    # ``X | Y`` reports ``types.UnionType`` as its origin on Python 3.10+,
    # while ``Union[X, Y]`` reports ``typing.Union``.
    union_origins = (Union, getattr(types, "UnionType", Union))
    if get_origin(type_hint) in union_origins:
        candidates = get_args(type_hint)
    else:
        candidates = (type_hint,)

    for candidate in candidates:
        origin = get_origin(candidate) or candidate
        # Origins such as Literal/Annotated are not classes; calling
        # issubclass() on them raises TypeError.
        if isinstance(origin, type) and issubclass(origin, AsyncIterator):
            inner_args = get_args(candidate)
            # An unparameterized AsyncIterator has no element type to report.
            if inner_args:
                return inner_args[0]
    return None
 |