for agents API, provider data from the header is not parsed as for agents there is no provider_data_validator meta-reference implementation. Added Together data validator as the provider_data_validator for now. Did some code changes accordingly.

This commit is contained in:
Yogish Baliga 2024-09-27 14:59:24 -07:00
parent 208b861289
commit 572c01f454
5 changed files with 32 additions and 14 deletions

View file

@ -18,6 +18,7 @@ def get_request_provider_data() -> Any:
def set_request_provider_data(headers: Dict[str, str], validator_classes: List[str]): def set_request_provider_data(headers: Dict[str, str], validator_classes: List[str]):
if not validator_classes: if not validator_classes:
return return

View file

@ -199,9 +199,12 @@ async def lifespan(app: FastAPI):
def create_dynamic_passthrough( def create_dynamic_passthrough(
downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None downstream_url: str,
provider_data_validators: List[str],
downstream_headers: Optional[Dict[str, str]] = None,
): ):
async def endpoint(request: Request): async def endpoint(request: Request):
set_request_provider_data(request.headers, provider_data_validators)
return await passthrough(request, downstream_url, downstream_headers) return await passthrough(request, downstream_url, downstream_headers)
return endpoint return endpoint
@ -223,7 +226,6 @@ def create_dynamic_typed_route(
async def endpoint(request: Request, **kwargs): async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__) await start_trace(func.__name__)
set_request_provider_data(request.headers, provider_data_validators) set_request_provider_data(request.headers, provider_data_validators)
async def sse_generator(event_gen): async def sse_generator(event_gen):
@ -446,6 +448,16 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
impl = impls[api] impl = impls[api]
provider_spec = specs[api] provider_spec = specs[api]
validators = []
if isinstance(provider_spec, AutoRoutedProviderSpec):
inner_specs = specs[provider_spec.routing_table_api].inner_specs
for spec in inner_specs:
if spec.provider_data_validator:
validators.append(spec.provider_data_validator)
elif not isinstance(provider_spec, RoutingTableProviderSpec):
if provider_spec.provider_data_validator:
validators.append(provider_spec.provider_data_validator)
if ( if (
isinstance(provider_spec, RemoteProviderSpec) isinstance(provider_spec, RemoteProviderSpec)
and provider_spec.adapter is None and provider_spec.adapter is None
@ -453,7 +465,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
for endpoint in endpoints: for endpoint in endpoints:
url = impl.__provider_config__.url.rstrip("/") + endpoint.route url = impl.__provider_config__.url.rstrip("/") + endpoint.route
getattr(app, endpoint.method)(endpoint.route)( getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url) create_dynamic_passthrough(url, validators)
) )
else: else:
for endpoint in endpoints: for endpoint in endpoints:
@ -465,16 +477,6 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
impl_method = getattr(impl, endpoint.name) impl_method = getattr(impl, endpoint.name)
validators = []
if isinstance(provider_spec, AutoRoutedProviderSpec):
inner_specs = specs[provider_spec.routing_table_api].inner_specs
for spec in inner_specs:
if spec.provider_data_validator:
validators.append(spec.provider_data_validator)
elif not isinstance(provider_spec, RoutingTableProviderSpec):
if provider_spec.provider_data_validator:
validators.append(provider_spec.provider_data_validator)
getattr(app, endpoint.method)(endpoint.route, response_model=None)( getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route( create_dynamic_typed_route(
impl_method, impl_method,

View file

@ -3,7 +3,6 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import AsyncGenerator from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
@ -61,6 +60,20 @@ class TogetherInferenceAdapter(Inference):
role = "tool" role = "tool"
else: else:
role = message.role role = message.role
if role == "user" and type(message.content) == list:
contents = []
for content in message.content:
if type(content) == str:
contents.append({"type": "text", "text": content})
elif type(content) == ImageMedia:
contents.append(
{
"type": "image_url",
"image_url": {"url": content.image.uri},
}
)
message.content = contents
together_messages.append({"role": role, "content": message.content}) together_messages.append({"role": role, "content": message.content})
return together_messages return together_messages

View file

@ -541,6 +541,7 @@ class ChatAgent(ShieldRunnerMixin):
) )
) )
) )
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload( payload=AgentTurnResponseStepProgressPayload(

View file

@ -29,6 +29,7 @@ def available_providers() -> List[ProviderSpec]:
Api.safety, Api.safety,
Api.memory, Api.memory,
], ],
provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator",
), ),
remote_provider_spec( remote_provider_spec(
api=Api.agents, api=Api.agents,