mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
for agents API, provider data from the header is not parsed as for agents there is no provider_data_validator meta-reference implementation. Added Together data validator as the provider_data_validator for now. Did some code changes accordingly.
This commit is contained in:
parent
208b861289
commit
572c01f454
5 changed files with 32 additions and 14 deletions
|
@ -18,6 +18,7 @@ def get_request_provider_data() -> Any:
|
|||
|
||||
|
||||
def set_request_provider_data(headers: Dict[str, str], validator_classes: List[str]):
|
||||
|
||||
if not validator_classes:
|
||||
return
|
||||
|
||||
|
|
|
@ -199,9 +199,12 @@ async def lifespan(app: FastAPI):
|
|||
|
||||
|
||||
def create_dynamic_passthrough(
|
||||
downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
|
||||
downstream_url: str,
|
||||
provider_data_validators: List[str],
|
||||
downstream_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
async def endpoint(request: Request):
|
||||
set_request_provider_data(request.headers, provider_data_validators)
|
||||
return await passthrough(request, downstream_url, downstream_headers)
|
||||
|
||||
return endpoint
|
||||
|
@ -223,7 +226,6 @@ def create_dynamic_typed_route(
|
|||
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
set_request_provider_data(request.headers, provider_data_validators)
|
||||
|
||||
async def sse_generator(event_gen):
|
||||
|
@ -446,6 +448,16 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
impl = impls[api]
|
||||
|
||||
provider_spec = specs[api]
|
||||
validators = []
|
||||
if isinstance(provider_spec, AutoRoutedProviderSpec):
|
||||
inner_specs = specs[provider_spec.routing_table_api].inner_specs
|
||||
for spec in inner_specs:
|
||||
if spec.provider_data_validator:
|
||||
validators.append(spec.provider_data_validator)
|
||||
elif not isinstance(provider_spec, RoutingTableProviderSpec):
|
||||
if provider_spec.provider_data_validator:
|
||||
validators.append(provider_spec.provider_data_validator)
|
||||
|
||||
if (
|
||||
isinstance(provider_spec, RemoteProviderSpec)
|
||||
and provider_spec.adapter is None
|
||||
|
@ -453,7 +465,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
for endpoint in endpoints:
|
||||
url = impl.__provider_config__.url.rstrip("/") + endpoint.route
|
||||
getattr(app, endpoint.method)(endpoint.route)(
|
||||
create_dynamic_passthrough(url)
|
||||
create_dynamic_passthrough(url, validators)
|
||||
)
|
||||
else:
|
||||
for endpoint in endpoints:
|
||||
|
@ -465,16 +477,6 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
|
||||
impl_method = getattr(impl, endpoint.name)
|
||||
|
||||
validators = []
|
||||
if isinstance(provider_spec, AutoRoutedProviderSpec):
|
||||
inner_specs = specs[provider_spec.routing_table_api].inner_specs
|
||||
for spec in inner_specs:
|
||||
if spec.provider_data_validator:
|
||||
validators.append(spec.provider_data_validator)
|
||||
elif not isinstance(provider_spec, RoutingTableProviderSpec):
|
||||
if provider_spec.provider_data_validator:
|
||||
validators.append(provider_spec.provider_data_validator)
|
||||
|
||||
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
||||
create_dynamic_typed_route(
|
||||
impl_method,
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
@ -61,6 +60,20 @@ class TogetherInferenceAdapter(Inference):
|
|||
role = "tool"
|
||||
else:
|
||||
role = message.role
|
||||
|
||||
if role == "user" and type(message.content) == list:
|
||||
contents = []
|
||||
for content in message.content:
|
||||
if type(content) == str:
|
||||
contents.append({"type": "text", "text": content})
|
||||
elif type(content) == ImageMedia:
|
||||
contents.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": content.image.uri},
|
||||
}
|
||||
)
|
||||
message.content = contents
|
||||
together_messages.append({"role": role, "content": message.content})
|
||||
|
||||
return together_messages
|
||||
|
|
|
@ -541,6 +541,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
)
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
|
|
|
@ -29,6 +29,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
Api.safety,
|
||||
Api.memory,
|
||||
],
|
||||
provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.agents,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue