mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
fixes

This commit is contained in:
parent e61add29e0
commit 6115679679

1 changed file with 40 additions and 26 deletions
@@ -31,6 +31,12 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
     def __init__(self, config: PassthroughImplConfig) -> None:
         self.config = config
 
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
     async def unregister_model(self, model_id: str) -> None:
         pass
 
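For orientation, a minimal sketch of how these lifecycle no-ops get driven; the import path and the PassthroughImplConfig fields shown are assumptions, not taken from this diff:

    import asyncio

    # Assumed import path and config field; check the provider package for the real ones.
    from llama_stack.providers.remote.inference.passthrough import (
        PassthroughImplConfig,
        PassthroughInferenceAdapter,
    )

    async def main() -> None:
        adapter = PassthroughInferenceAdapter(PassthroughImplConfig(url="http://localhost:8000"))
        await adapter.initialize()  # currently a no-op
        # ... issue inference requests here ...
        await adapter.shutdown()    # also a no-op today

    asyncio.run(main())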
@@ -53,8 +59,27 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
         response.raise_for_status()
 
         response_data = response.json()
-        # The response should be a list of Model objects
-        return [Model.model_validate(m) for m in response_data]
+        models_data = response_data["data"]
+
+        # Convert from OpenAI format to Llama Stack Model format
+        models = []
+        for model_data in models_data:
+            downstream_model_id = model_data["id"]
+            custom_metadata = model_data.get("custom_metadata", {})
+
+            # Prefix identifier with provider ID for local registry
+            local_identifier = f"{self.__provider_id__}/{downstream_model_id}"
+
+            model = Model(
+                identifier=local_identifier,
+                provider_id=self.__provider_id__,
+                provider_resource_id=downstream_model_id,
+                model_type=custom_metadata.get("model_type", "llm"),
+                metadata=custom_metadata,
+            )
+            models.append(model)
+
+        return models
 
     async def should_refresh_models(self) -> bool:
         """Passthrough should refresh models since they come from downstream dynamically."""
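To see what the new conversion produces, here is a self-contained sketch of the same loop run on a hand-written downstream payload; the payload shape (a "data" list whose entries carry "id" and "custom_metadata" keys) is inferred from the code above, not from downstream documentation:

    # Hand-written example of a downstream /v1/models-style response body.
    response_data = {
        "data": [
            {"id": "llama-3.1-8b", "custom_metadata": {"model_type": "llm"}},
            {"id": "nomic-embed", "custom_metadata": {"model_type": "embedding"}},
        ]
    }
    provider_id = "passthrough"  # stands in for self.__provider_id__

    for model_data in response_data["data"]:
        downstream_model_id = model_data["id"]
        custom_metadata = model_data.get("custom_metadata", {})
        # Same prefixing rule as the loop above: "<provider_id>/<downstream_model_id>".
        print(f"{provider_id}/{downstream_model_id}", custom_metadata.get("model_type", "llm"))
    # passthrough/llama-3.1-8b llm
    # passthrough/nomic-embed embedding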
@@ -159,23 +184,22 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
             if not data_str or data_str == "[DONE]":
                 continue
 
-            try:
-                data = json.loads(data_str)
-                yield response_type.model_validate(data)
-            except Exception:
-                # Log and skip malformed chunks
-                continue
+            data = json.loads(data_str)
+
+            # Fix OpenAI compatibility: finish_reason can be null in intermediate chunks
+            # but our Pydantic model may not accept null
+            if "choices" in data:
+                for choice in data["choices"]:
+                    if choice.get("finish_reason") is None:
+                        choice["finish_reason"] = ""
+
+            yield response_type.model_validate(data)
 
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
     ) -> OpenAICompletion:
-        model_obj = await self.model_store.get_model(params.model)
-
-        # Create a copy with the provider's model ID
-        params = params.model_copy()
-        params.model = model_obj.provider_resource_id
-
+        # params.model is already the provider_resource_id (router translated it)
         request_params = params.model_dump(exclude_none=True)
 
         return await self._make_request(
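The finish_reason normalization is easiest to verify on a raw chunk; a standalone sketch with a hand-written intermediate chunk (not captured server output):

    import json

    data_str = '{"choices": [{"delta": {"content": "Hi"}, "finish_reason": null}]}'

    data = json.loads(data_str)
    # json.loads maps JSON null to Python None; rewrite it to "" so a Pydantic
    # model that types finish_reason as a plain str still validates the chunk.
    if "choices" in data:
        for choice in data["choices"]:
            if choice.get("finish_reason") is None:
                choice["finish_reason"] = ""

    assert data["choices"][0]["finish_reason"] == ""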
@ -189,12 +213,7 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
|
|||
self,
|
||||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
model_obj = await self.model_store.get_model(params.model)
|
||||
|
||||
# Create a copy with the provider's model ID
|
||||
params = params.model_copy()
|
||||
params.model = model_obj.provider_resource_id
|
||||
|
||||
# params.model is already the provider_resource_id (router translated it)
|
||||
request_params = params.model_dump(exclude_none=True)
|
||||
|
||||
return await self._make_request(
|
||||
|
|
@@ -208,12 +227,7 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
         self,
         params: OpenAIEmbeddingsRequestWithExtraBody,
     ) -> OpenAIEmbeddingsResponse:
-        model_obj = await self.model_store.get_model(params.model)
-
-        # Create a copy with the provider's model ID
-        params = params.model_copy()
-        params.model = model_obj.provider_resource_id
-
+        # params.model is already the provider_resource_id (router translated it)
        request_params = params.model_dump(exclude_none=True)
 
         return await self._make_request(
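All three methods now rely on the router having already rewritten params.model, plus Pydantic's exclude_none to keep the outgoing body minimal. A self-contained sketch of that dump behavior; PassthroughRequest is a hypothetical stand-in, not one of the real request types:

    from pydantic import BaseModel

    class PassthroughRequest(BaseModel):
        # Hypothetical stand-in for OpenAICompletionRequestWithExtraBody and friends.
        model: str
        temperature: float | None = None
        max_tokens: int | None = None

    params = PassthroughRequest(model="llama-3.1-8b", max_tokens=128)
    # exclude_none drops optional fields left at None, so the downstream server
    # only receives the arguments the caller actually set.
    print(params.model_dump(exclude_none=True))
    # {'model': 'llama-3.1-8b', 'max_tokens': 128}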