diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py
index 27b8b531f..19111374d 100644
--- a/llama_stack/distribution/request_headers.py
+++ b/llama_stack/distribution/request_headers.py
@@ -18,6 +18,7 @@ def get_request_provider_data() -> Any:
 def set_request_provider_data(headers: Dict[str, str], validator_classes: List[str]):
+    if not validator_classes:
+        return
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 7a3e6276c..06e336741 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -199,9 +199,12 @@ async def lifespan(app: FastAPI):
 def create_dynamic_passthrough(
-    downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
+    downstream_url: str,
+    provider_data_validators: List[str],
+    downstream_headers: Optional[Dict[str, str]] = None,
 ):
     async def endpoint(request: Request):
+        set_request_provider_data(request.headers, provider_data_validators)
         return await passthrough(request, downstream_url, downstream_headers)
 
     return endpoint
@@ -223,7 +226,6 @@ def create_dynamic_typed_route(
     async def endpoint(request: Request, **kwargs):
         await start_trace(func.__name__)
 
-        set_request_provider_data(request.headers, provider_data_validators)
 
         async def sse_generator(event_gen):
@@ -446,6 +448,16 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
         impl = impls[api]
         provider_spec = specs[api]
 
+        validators = []
+        if isinstance(provider_spec, AutoRoutedProviderSpec):
+            inner_specs = specs[provider_spec.routing_table_api].inner_specs
+            for spec in inner_specs:
+                if spec.provider_data_validator:
+                    validators.append(spec.provider_data_validator)
+        elif not isinstance(provider_spec, RoutingTableProviderSpec):
+            if provider_spec.provider_data_validator:
+                validators.append(provider_spec.provider_data_validator)
+
         if (
             isinstance(provider_spec, RemoteProviderSpec)
             and provider_spec.adapter is None
@@ -453,7 +465,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
             for endpoint in endpoints:
                 url = impl.__provider_config__.url.rstrip("/") + endpoint.route
                 getattr(app, endpoint.method)(endpoint.route)(
-                    create_dynamic_passthrough(url)
+                    create_dynamic_passthrough(url, validators)
                 )
         else:
             for endpoint in endpoints:
@@ -465,16 +477,6 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
 
                 impl_method = getattr(impl, endpoint.name)
 
-                validators = []
-                if isinstance(provider_spec, AutoRoutedProviderSpec):
-                    inner_specs = specs[provider_spec.routing_table_api].inner_specs
-                    for spec in inner_specs:
-                        if spec.provider_data_validator:
-                            validators.append(spec.provider_data_validator)
-                elif not isinstance(provider_spec, RoutingTableProviderSpec):
-                    if provider_spec.provider_data_validator:
-                        validators.append(provider_spec.provider_data_validator)
-
                 getattr(app, endpoint.method)(endpoint.route, response_model=None)(
                     create_dynamic_typed_route(
                         impl_method,
diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py
index cafca3fdf..92b3418f8 100644
--- a/llama_stack/providers/adapters/inference/together/together.py
+++ b/llama_stack/providers/adapters/inference/together/together.py
@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
 from typing import AsyncGenerator
 
 from llama_models.llama3.api.chat_format import ChatFormat
@@ -61,6 +60,20 @@ class TogetherInferenceAdapter(Inference):
                 role = "tool"
             else:
                 role = message.role
+
+            if role == "user" and type(message.content) == list:
+                contents = []
+                for content in message.content:
+                    if type(content) == str:
+                        contents.append({"type": "text", "text": content})
+                    elif type(content) == ImageMedia:
+                        contents.append(
+                            {
+                                "type": "image_url",
+                                "image_url": {"url": content.image.uri},
+                            }
+                        )
+                message.content = contents
             together_messages.append({"role": role, "content": message.content})
 
         return together_messages
diff --git a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
index 9db6b79b5..c828ac51b 100644
--- a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
+++ b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
@@ -541,6 +541,7 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
             )
+
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepProgressPayload(
diff --git a/llama_stack/providers/registry/agents.py b/llama_stack/providers/registry/agents.py
index 16a872572..bd58a6952 100644
--- a/llama_stack/providers/registry/agents.py
+++ b/llama_stack/providers/registry/agents.py
@@ -29,6 +29,7 @@ def available_providers() -> List[ProviderSpec]:
                 Api.safety,
                 Api.memory,
             ],
+            provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator",
        ),
        remote_provider_spec(
            api=Api.agents,
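
Usage sketch (not part of the patch): with set_request_provider_data now wired into create_dynamic_passthrough, a passthrough endpoint can pick up per-request provider credentials from the provider-data header and hand them to the configured validators. The header name, server URL, route, and together_api_key field below are assumptions for illustration, not taken from this diff.

# Sketch only: header name, route, and payload key are assumed, not verified against this diff.
import json

import httpx

# Provider data that TogetherProviderDataValidator is expected to validate.
provider_data = {"together_api_key": "<YOUR_TOGETHER_API_KEY>"}

resp = httpx.post(
    "http://localhost:5000/inference/chat_completion",  # assumed server address and route
    headers={"X-LlamaStack-ProviderData": json.dumps(provider_data)},
    json={
        "model": "Llama3.1-8B-Instruct",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": False,
    },
    timeout=60,
)
print(resp.status_code, resp.text)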