This commit is contained in:
Ashwin Bharambe 2025-11-05 14:59:33 -08:00
parent e61add29e0
commit 6115679679

View file

@ -31,6 +31,12 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
def __init__(self, config: PassthroughImplConfig) -> None:
    """Create the adapter and remember its configuration.

    Args:
        config: Passthrough provider configuration; kept as ``self.config``
            without being copied or inspected here.
    """
    self.config = config
async def initialize(self) -> None:
    """Start-up hook for the adapter; performs no work."""
    return None
async def shutdown(self) -> None:
    """Tear-down hook for the adapter; performs no work."""
    return None
async def unregister_model(self, model_id: str) -> None:
    """Accept a model unregistration request and do nothing.

    Args:
        model_id: Identifier of the model being unregistered; it is not
            read, and no adapter state is modified.
    """
    return None
@ -53,8 +59,27 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
response.raise_for_status()
response_data = response.json()
# The response should be a list of Model objects
return [Model.model_validate(m) for m in response_data]
models_data = response_data["data"]
# Convert from OpenAI format to Llama Stack Model format
models = []
for model_data in models_data:
downstream_model_id = model_data["id"]
custom_metadata = model_data.get("custom_metadata", {})
# Prefix identifier with provider ID for local registry
local_identifier = f"{self.__provider_id__}/{downstream_model_id}"
model = Model(
identifier=local_identifier,
provider_id=self.__provider_id__,
provider_resource_id=downstream_model_id,
model_type=custom_metadata.get("model_type", "llm"),
metadata=custom_metadata,
)
models.append(model)
return models
async def should_refresh_models(self) -> bool:
"""Passthrough should refresh models since they come from downstream dynamically."""
@ -159,23 +184,22 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
if not data_str or data_str == "[DONE]":
continue
try:
data = json.loads(data_str)
yield response_type.model_validate(data)
except Exception:
# Log and skip malformed chunks
continue
data = json.loads(data_str)
# Fix OpenAI compatibility: finish_reason can be null in intermediate chunks
# but our Pydantic model may not accept null
if "choices" in data:
for choice in data["choices"]:
if choice.get("finish_reason") is None:
choice["finish_reason"] = ""
yield response_type.model_validate(data)
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(params.model)
# Create a copy with the provider's model ID
params = params.model_copy()
params.model = model_obj.provider_resource_id
# params.model is already the provider_resource_id (router translated it)
request_params = params.model_dump(exclude_none=True)
return await self._make_request(
@ -189,12 +213,7 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self.model_store.get_model(params.model)
# Create a copy with the provider's model ID
params = params.model_copy()
params.model = model_obj.provider_resource_id
# params.model is already the provider_resource_id (router translated it)
request_params = params.model_dump(exclude_none=True)
return await self._make_request(
@ -208,12 +227,7 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
self,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
model_obj = await self.model_store.get_model(params.model)
# Create a copy with the provider's model ID
params = params.model_copy()
params.model = model_obj.provider_resource_id
# params.model is already the provider_resource_id (router translated it)
request_params = params.model_dump(exclude_none=True)
return await self._make_request(