From 6115679679509db5621620d80bdfbf3346aade7b Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Wed, 5 Nov 2025 14:59:33 -0800
Subject: [PATCH] passthrough: fix model listing, streaming chunk handling,
 and model id resolution

- Add no-op initialize()/shutdown() lifecycle hooks.
- Convert /v1/models responses from OpenAI list format to Llama Stack
  Model objects, prefixing identifiers with the provider ID.
- Normalize null finish_reason in intermediate streaming chunks before
  Pydantic validation.
- Stop re-resolving params.model in the OpenAI-compatible endpoints;
  the router has already translated it to the provider_resource_id.
---
 .../inference/passthrough/passthrough.py      | 66 +++++++++++--------
 1 file changed, 40 insertions(+), 26 deletions(-)

diff --git a/src/llama_stack/providers/remote/inference/passthrough/passthrough.py b/src/llama_stack/providers/remote/inference/passthrough/passthrough.py
index 81581a238..d991a663b 100644
--- a/src/llama_stack/providers/remote/inference/passthrough/passthrough.py
+++ b/src/llama_stack/providers/remote/inference/passthrough/passthrough.py
@@ -31,6 +31,12 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
     def __init__(self, config: PassthroughImplConfig) -> None:
         self.config = config
 
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
     async def unregister_model(self, model_id: str) -> None:
         pass
 
@@ -53,8 +59,27 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
         response.raise_for_status()
         response_data = response.json()
 
-        # The response should be a list of Model objects
-        return [Model.model_validate(m) for m in response_data]
+        models_data = response_data["data"]
+
+        # Convert from OpenAI format to Llama Stack Model format
+        models = []
+        for model_data in models_data:
+            downstream_model_id = model_data["id"]
+            custom_metadata = model_data.get("custom_metadata", {})
+
+            # Prefix identifier with provider ID for local registry
+            local_identifier = f"{self.__provider_id__}/{downstream_model_id}"
+
+            model = Model(
+                identifier=local_identifier,
+                provider_id=self.__provider_id__,
+                provider_resource_id=downstream_model_id,
+                model_type=custom_metadata.get("model_type", "llm"),
+                metadata=custom_metadata,
+            )
+            models.append(model)
+
+        return models
 
     async def should_refresh_models(self) -> bool:
         """Passthrough should refresh models since they come from downstream dynamically."""
@@ -159,23 +184,22 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
                 if not data_str or data_str == "[DONE]":
                     continue
 
-                try:
-                    data = json.loads(data_str)
-                    yield response_type.model_validate(data)
-                except Exception:
-                    # Log and skip malformed chunks
-                    continue
+                data = json.loads(data_str)
+
+                # Fix OpenAI compatibility: finish_reason can be null in intermediate chunks
+                # but our Pydantic model may not accept null
+                if "choices" in data:
+                    for choice in data["choices"]:
+                        if choice.get("finish_reason") is None:
+                            choice["finish_reason"] = ""
+
+                yield response_type.model_validate(data)
 
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
     ) -> OpenAICompletion:
-        model_obj = await self.model_store.get_model(params.model)
-
-        # Create a copy with the provider's model ID
-        params = params.model_copy()
-        params.model = model_obj.provider_resource_id
-
+        # params.model is already the provider_resource_id (router translated it)
         request_params = params.model_dump(exclude_none=True)
 
         return await self._make_request(
@@ -189,12 +213,7 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        model_obj = await self.model_store.get_model(params.model)
-
-        # Create a copy with the provider's model ID
-        params = params.model_copy()
-        params.model = model_obj.provider_resource_id
-
+        # params.model is already the provider_resource_id (router translated it)
         request_params = params.model_dump(exclude_none=True)
 
         return await self._make_request(
@@ -208,12 +227,7 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
         self,
         params: OpenAIEmbeddingsRequestWithExtraBody,
     ) -> OpenAIEmbeddingsResponse:
-        model_obj = await self.model_store.get_model(params.model)
-
-        # Create a copy with the provider's model ID
-        params = params.model_copy()
-        params.model = model_obj.provider_resource_id
-
+        # params.model is already the provider_resource_id (router translated it)
         request_params = params.model_dump(exclude_none=True)
 
         return await self._make_request(
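
Reviewer note (text below the diff is ignored by git am): a minimal,
self-contained sketch of the list_models() conversion added in the
second hunk. The payload here is invented for illustration; the exact
shape of the downstream /v1/models response, including the optional
custom_metadata field, is an assumption, and plain dicts stand in for
the Llama Stack Model class.

# Illustrative sketch only; mirrors the conversion loop in list_models().
# The payload shape (and the "custom_metadata" field) is an assumption
# about what the downstream server returns; dicts stand in for Model.
response_data = {
    "data": [
        {"id": "llama-3.3-70b", "custom_metadata": {"model_type": "llm"}},
        {"id": "nomic-embed", "custom_metadata": {"model_type": "embedding"}},
    ]
}

provider_id = "passthrough"  # stands in for self.__provider_id__

models = []
for model_data in response_data["data"]:
    downstream_model_id = model_data["id"]
    custom_metadata = model_data.get("custom_metadata", {})
    models.append(
        {
            # Local registry key is "<provider_id>/<downstream model id>"
            "identifier": f"{provider_id}/{downstream_model_id}",
            "provider_id": provider_id,
            "provider_resource_id": downstream_model_id,
            "model_type": custom_metadata.get("model_type", "llm"),
            "metadata": custom_metadata,
        }
    )

assert models[0]["identifier"] == "passthrough/llama-3.3-70b"
assert models[1]["model_type"] == "embedding"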
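
Reviewer note: a self-contained sketch of the finish_reason
normalization in the streaming hunk. The SSE chunk body is hand-written
for illustration; intermediate OpenAI streaming chunks carry
"finish_reason": null, which the patch rewrites to an empty string
before Pydantic validation.

import json

# Hand-written example of an intermediate OpenAI streaming chunk; only
# the fields relevant to the normalization are included.
data_str = (
    '{"id": "chatcmpl-1", "object": "chat.completion.chunk",'
    ' "choices": [{"index": 0, "delta": {"content": "Hi"},'
    ' "finish_reason": null}]}'
)

data = json.loads(data_str)

# Same normalization as the patch: a null finish_reason would fail
# strict validation, so it is replaced with an empty string.
if "choices" in data:
    for choice in data["choices"]:
        if choice.get("finish_reason") is None:
            choice["finish_reason"] = ""

assert data["choices"][0]["finish_reason"] == ""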