mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-22 04:27:52 +00:00
Merge branch 'main' into feat/add-url-to-paginated-response
This commit is contained in:
commit
b5047db685
24 changed files with 911 additions and 856 deletions
|
@ -180,6 +180,7 @@ def get_provider_registry(
|
|||
if provider_type_key in ret[api]:
|
||||
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
|
||||
ret[api][provider_type_key] = spec
|
||||
logger.info(f"Successfully loaded external provider {provider_type_key}")
|
||||
except yaml.YAMLError as yaml_err:
|
||||
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
|
||||
raise yaml_err
|
||||
|
|
|
@ -394,9 +394,13 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
|||
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
|
||||
missing_methods.append((name, "signature_mismatch"))
|
||||
else:
|
||||
# Check if the method is actually implemented in the class
|
||||
method_owner = next((cls for cls in mro if name in cls.__dict__), None)
|
||||
if method_owner is None or method_owner.__name__ == protocol.__name__:
|
||||
# Check if the method has a concrete implementation (not just a protocol stub)
|
||||
# Find all classes in MRO that define this method
|
||||
method_owners = [cls for cls in mro if name in cls.__dict__]
|
||||
|
||||
# Allow methods from mixins/parents, only reject if ONLY the protocol defines it
|
||||
if len(method_owners) == 1 and method_owners[0].__name__ == protocol.__name__:
|
||||
# Only reject if the method is ONLY defined in the protocol itself (abstract stub)
|
||||
missing_methods.append((name, "not_actually_implemented"))
|
||||
|
||||
if missing_methods:
|
||||
|
|
|
@ -163,6 +163,9 @@ class InferenceRouter(Inference):
|
|||
messages: list[Message] | InterleavedContent,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
) -> int | None:
|
||||
if not hasattr(self, "formatter") or self.formatter is None:
|
||||
return None
|
||||
|
||||
if isinstance(messages, list):
|
||||
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
|
||||
else:
|
||||
|
|
|
@ -17,7 +17,7 @@ from llama_stack.apis.vector_io import (
|
|||
VectorStoreDeleteResponse,
|
||||
VectorStoreListResponse,
|
||||
VectorStoreObject,
|
||||
VectorStoreSearchResponse,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
@ -108,7 +108,7 @@ class VectorIORouter(VectorIO):
|
|||
# OpenAI Vector Stores API endpoints
|
||||
async def openai_create_vector_store(
|
||||
self,
|
||||
name: str | None = None,
|
||||
name: str,
|
||||
file_ids: list[str] | None = None,
|
||||
expires_after: dict[str, Any] | None = None,
|
||||
chunking_strategy: dict[str, Any] | None = None,
|
||||
|
@ -151,8 +151,8 @@ class VectorIORouter(VectorIO):
|
|||
|
||||
async def openai_list_vector_stores(
|
||||
self,
|
||||
limit: int = 20,
|
||||
order: str = "desc",
|
||||
limit: int | None = 20,
|
||||
order: str | None = "desc",
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
) -> VectorStoreListResponse:
|
||||
|
@ -239,10 +239,10 @@ class VectorIORouter(VectorIO):
|
|||
vector_store_id: str,
|
||||
query: str | list[str],
|
||||
filters: dict[str, Any] | None = None,
|
||||
max_num_results: int = 10,
|
||||
max_num_results: int | None = 10,
|
||||
ranking_options: dict[str, Any] | None = None,
|
||||
rewrite_query: bool = False,
|
||||
) -> VectorStoreSearchResponse:
|
||||
rewrite_query: bool | None = False,
|
||||
) -> VectorStoreSearchResponsePage:
|
||||
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
|
|
|
@ -84,6 +84,7 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
|
|||
class OAuth2JWKSConfig(BaseModel):
|
||||
# The JWKS URI for collecting public keys
|
||||
uri: str
|
||||
token: str | None = Field(default=None, description="token to authorise access to jwks")
|
||||
key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
|
||||
|
||||
|
||||
|
@ -246,9 +247,12 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
if self.config.jwks is None:
|
||||
raise ValueError("JWKS is not configured")
|
||||
if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
|
||||
headers = {}
|
||||
if self.config.jwks.token:
|
||||
headers["Authorization"] = f"Bearer {self.config.jwks.token}"
|
||||
verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
|
||||
async with httpx.AsyncClient(verify=verify) as client:
|
||||
res = await client.get(self.config.jwks.uri, timeout=5)
|
||||
res = await client.get(self.config.jwks.uri, timeout=5, headers=headers)
|
||||
res.raise_for_status()
|
||||
jwks_data = res.json()["keys"]
|
||||
updated = {}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue