Merge branch 'main' into feat/add-url-to-paginated-response

This commit is contained in:
Rohan Awhad 2025-06-13 13:07:45 -04:00 committed by GitHub
commit b5047db685
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 911 additions and 856 deletions

View file

@ -180,6 +180,7 @@ def get_provider_registry(
if provider_type_key in ret[api]:
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
ret[api][provider_type_key] = spec
logger.info(f"Successfully loaded external provider {provider_type_key}")
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
raise yaml_err

View file

@ -394,9 +394,13 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method is actually implemented in the class
method_owner = next((cls for cls in mro if name in cls.__dict__), None)
if method_owner is None or method_owner.__name__ == protocol.__name__:
# Check if the method has a concrete implementation (not just a protocol stub)
# Find all classes in MRO that define this method
method_owners = [cls for cls in mro if name in cls.__dict__]
# Allow methods from mixins/parents, only reject if ONLY the protocol defines it
if len(method_owners) == 1 and method_owners[0].__name__ == protocol.__name__:
# Only reject if the method is ONLY defined in the protocol itself (abstract stub)
missing_methods.append((name, "not_actually_implemented"))
if missing_methods:

View file

@ -163,6 +163,9 @@ class InferenceRouter(Inference):
messages: list[Message] | InterleavedContent,
tool_prompt_format: ToolPromptFormat | None = None,
) -> int | None:
if not hasattr(self, "formatter") or self.formatter is None:
return None
if isinstance(messages, list):
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
else:

View file

@ -17,7 +17,7 @@ from llama_stack.apis.vector_io import (
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponse,
VectorStoreSearchResponsePage,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
@ -108,7 +108,7 @@ class VectorIORouter(VectorIO):
# OpenAI Vector Stores API endpoints
async def openai_create_vector_store(
self,
name: str | None = None,
name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
@ -151,8 +151,8 @@ class VectorIORouter(VectorIO):
async def openai_list_vector_stores(
self,
limit: int = 20,
order: str = "desc",
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
@ -239,10 +239,10 @@ class VectorIORouter(VectorIO):
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int = 10,
max_num_results: int | None = 10,
ranking_options: dict[str, Any] | None = None,
rewrite_query: bool = False,
) -> VectorStoreSearchResponse:
rewrite_query: bool | None = False,
) -> VectorStoreSearchResponsePage:
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)

View file

@ -84,6 +84,7 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
class OAuth2JWKSConfig(BaseModel):
# The JWKS URI for collecting public keys
uri: str
token: str | None = Field(default=None, description="token to authorise access to jwks")
key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
@ -246,9 +247,12 @@ class OAuth2TokenAuthProvider(AuthProvider):
if self.config.jwks is None:
raise ValueError("JWKS is not configured")
if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
headers = {}
if self.config.jwks.token:
headers["Authorization"] = f"Bearer {self.config.jwks.token}"
verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
async with httpx.AsyncClient(verify=verify) as client:
res = await client.get(self.config.jwks.uri, timeout=5)
res = await client.get(self.config.jwks.uri, timeout=5, headers=headers)
res.raise_for_status()
jwks_data = res.json()["keys"]
updated = {}