This commit is contained in:
Ashwin Bharambe 2025-11-05 14:59:33 -08:00
parent e61add29e0
commit 6115679679

View file

@ -31,6 +31,12 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
def __init__(self, config: PassthroughImplConfig) -> None:
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def unregister_model(self, model_id: str) -> None:
pass
@ -53,8 +59,27 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
response.raise_for_status()
response_data = response.json()
# The response should be a list of Model objects models_data = response_data["data"]
return [Model.model_validate(m) for m in response_data]
# Convert from OpenAI format to Llama Stack Model format
models = []
for model_data in models_data:
downstream_model_id = model_data["id"]
custom_metadata = model_data.get("custom_metadata", {})
# Prefix identifier with provider ID for local registry
local_identifier = f"{self.__provider_id__}/{downstream_model_id}"
model = Model(
identifier=local_identifier,
provider_id=self.__provider_id__,
provider_resource_id=downstream_model_id,
model_type=custom_metadata.get("model_type", "llm"),
metadata=custom_metadata,
)
models.append(model)
return models
async def should_refresh_models(self) -> bool:
"""Passthrough should refresh models since they come from downstream dynamically."""
@ -159,23 +184,22 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
if not data_str or data_str == "[DONE]":
continue
try: data = json.loads(data_str)
data = json.loads(data_str)
yield response_type.model_validate(data) # Fix OpenAI compatibility: finish_reason can be null in intermediate chunks
except Exception: # but our Pydantic model may not accept null
# Log and skip malformed chunks if "choices" in data:
continue for choice in data["choices"]:
if choice.get("finish_reason") is None:
choice["finish_reason"] = ""
yield response_type.model_validate(data)
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(params.model) # params.model is already the provider_resource_id (router translated it)
# Create a copy with the provider's model ID
params = params.model_copy()
params.model = model_obj.provider_resource_id
request_params = params.model_dump(exclude_none=True)
return await self._make_request(
@ -189,12 +213,7 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self.model_store.get_model(params.model) # params.model is already the provider_resource_id (router translated it)
# Create a copy with the provider's model ID
params = params.model_copy()
params.model = model_obj.provider_resource_id
request_params = params.model_dump(exclude_none=True)
return await self._make_request(
@ -208,12 +227,7 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
self,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
model_obj = await self.model_store.get_model(params.model) # params.model is already the provider_resource_id (router translated it)
# Create a copy with the provider's model ID
params = params.model_copy()
params.model = model_obj.provider_resource_id
request_params = params.model_dump(exclude_none=True)
return await self._make_request(