mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-04 18:13:44 +00:00

fixes

This commit is contained in:
parent e61add29e0
commit 6115679679

1 changed file with 40 additions and 26 deletions
```diff
@@ -31,6 +31,12 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
     def __init__(self, config: PassthroughImplConfig) -> None:
         self.config = config
 
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
     async def unregister_model(self, model_id: str) -> None:
         pass
 
```
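The two new no-op hooks fill out the provider lifecycle contract. A minimal sketch of why they exist, assuming the stack awaits `initialize()`/`shutdown()` on each provider at startup and teardown (the driver below is illustrative, not llama-stack code):

```python
# Illustrative only: a stand-in provider plus a hypothetical startup/teardown
# driver. The passthrough adapter holds no pooled connections, so both hooks
# can safely be no-ops.
import asyncio

class PassthroughLike:
    async def initialize(self) -> None:
        pass  # nothing to set up; requests are proxied per call

    async def shutdown(self) -> None:
        pass  # no clients to close

async def main() -> None:
    provider = PassthroughLike()
    await provider.initialize()  # presumably invoked once at stack startup
    await provider.shutdown()    # and once at teardown

asyncio.run(main())
```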
```diff
@@ -53,8 +59,27 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
         response.raise_for_status()
 
         response_data = response.json()
-        # The response should be a list of Model objects
-        return [Model.model_validate(m) for m in response_data]
+        models_data = response_data["data"]
+
+        # Convert from OpenAI format to Llama Stack Model format
+        models = []
+        for model_data in models_data:
+            downstream_model_id = model_data["id"]
+            custom_metadata = model_data.get("custom_metadata", {})
+
+            # Prefix identifier with provider ID for local registry
+            local_identifier = f"{self.__provider_id__}/{downstream_model_id}"
+
+            model = Model(
+                identifier=local_identifier,
+                provider_id=self.__provider_id__,
+                provider_resource_id=downstream_model_id,
+                model_type=custom_metadata.get("model_type", "llm"),
+                metadata=custom_metadata,
+            )
+            models.append(model)
+
+        return models
 
     async def should_refresh_models(self) -> bool:
         """Passthrough should refresh models since they come from downstream dynamically."""
```
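The rewritten `list_models` path reads the OpenAI-style `data` envelope instead of validating the raw response body. A runnable sketch of the conversion against a fabricated downstream payload; the `Model` dataclass stands in for llama-stack's Pydantic type, and the provider id `passthrough0` and the `custom_metadata` contents are assumptions for illustration:

```python
from dataclasses import dataclass, field
from typing import Any

@dataclass
class Model:  # stand-in for llama-stack's Model type
    identifier: str
    provider_id: str
    provider_resource_id: str
    model_type: str = "llm"
    metadata: dict[str, Any] = field(default_factory=dict)

provider_id = "passthrough0"  # hypothetical __provider_id__
response_data = {  # fabricated /v1/models response
    "object": "list",
    "data": [
        {"id": "meta-llama/Llama-3.3-70B-Instruct", "object": "model"},
        {
            "id": "all-MiniLM-L6-v2",
            "object": "model",
            "custom_metadata": {"model_type": "embedding", "embedding_dimension": 384},
        },
    ],
}

models = []
for model_data in response_data["data"]:
    downstream_model_id = model_data["id"]
    custom_metadata = model_data.get("custom_metadata", {})
    models.append(
        Model(
            # Prefix with the provider id so the local registry key is unique
            identifier=f"{provider_id}/{downstream_model_id}",
            provider_id=provider_id,
            provider_resource_id=downstream_model_id,
            model_type=custom_metadata.get("model_type", "llm"),
            metadata=custom_metadata,
        )
    )

for m in models:
    print(m.identifier, m.model_type)
# passthrough0/meta-llama/Llama-3.3-70B-Instruct llm
# passthrough0/all-MiniLM-L6-v2 embedding
```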
```diff
@@ -159,23 +184,22 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
             if not data_str or data_str == "[DONE]":
                 continue
 
-            try:
-                data = json.loads(data_str)
-                yield response_type.model_validate(data)
-            except Exception:
-                # Log and skip malformed chunks
-                continue
+            data = json.loads(data_str)
+
+            # Fix OpenAI compatibility: finish_reason can be null in intermediate chunks
+            # but our Pydantic model may not accept null
+            if "choices" in data:
+                for choice in data["choices"]:
+                    if choice.get("finish_reason") is None:
+                        choice["finish_reason"] = ""
+
+            yield response_type.model_validate(data)
 
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
     ) -> OpenAICompletion:
-        model_obj = await self.model_store.get_model(params.model)
-
-        # Create a copy with the provider's model ID
-        params = params.model_copy()
-        params.model = model_obj.provider_resource_id
-
+        # params.model is already the provider_resource_id (router translated it)
        request_params = params.model_dump(exclude_none=True)
 
         return await self._make_request(
```
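The streaming change drops the silent try/except (malformed chunks now surface as errors) and instead normalizes `finish_reason` before Pydantic validation. A small sketch of that normalization on a fabricated OpenAI-format SSE chunk:

```python
import json

# Fabricated intermediate streaming chunk; note finish_reason is null.
raw = 'data: {"id":"chatcmpl-1","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}'
data_str = raw.removeprefix("data: ").strip()

if data_str and data_str != "[DONE]":
    data = json.loads(data_str)
    # Intermediate chunks carry finish_reason: null; coerce to "" so a
    # stricter model that expects a string still validates.
    if "choices" in data:
        for choice in data["choices"]:
            if choice.get("finish_reason") is None:
                choice["finish_reason"] = ""
    print(data["choices"][0]["finish_reason"] == "")  # True
```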
```diff
@@ -189,12 +213,7 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        model_obj = await self.model_store.get_model(params.model)
-
-        # Create a copy with the provider's model ID
-        params = params.model_copy()
-        params.model = model_obj.provider_resource_id
-
+        # params.model is already the provider_resource_id (router translated it)
         request_params = params.model_dump(exclude_none=True)
 
         return await self._make_request(
```
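All three endpoint methods now rely on the router having already translated the local identifier into the downstream model id, which is why the `model_store` lookup and the `model_copy()` dance disappear. A toy illustration of that assumption (the registry dict and `route()` helper are hypothetical, not the actual router code):

```python
# Hypothetical registry built from list_models: local identifier -> downstream id.
registry = {
    "passthrough0/meta-llama/Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct",
}

def route(model: str) -> str:
    # The router substitutes provider_resource_id before the adapter runs,
    # so the adapter can forward params.model as-is.
    return registry[model]

assert route("passthrough0/meta-llama/Llama-3.3-70B-Instruct") == "meta-llama/Llama-3.3-70B-Instruct"
```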
```diff
@@ -208,12 +227,7 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
         self,
         params: OpenAIEmbeddingsRequestWithExtraBody,
     ) -> OpenAIEmbeddingsResponse:
-        model_obj = await self.model_store.get_model(params.model)
-
-        # Create a copy with the provider's model ID
-        params = params.model_copy()
-        params.model = model_obj.provider_resource_id
-
+        # params.model is already the provider_resource_id (router translated it)
         request_params = params.model_dump(exclude_none=True)
 
         return await self._make_request(
```
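With the translation handled upstream, each method just serializes `params` and forwards it. A sketch of what `model_dump(exclude_none=True)` yields, using a hand-rolled stand-in for `OpenAIEmbeddingsRequestWithExtraBody` (its real fields may differ):

```python
from pydantic import BaseModel

class EmbeddingsRequest(BaseModel):  # assumed shape, for illustration
    model: str
    input: str | list[str]
    encoding_format: str | None = None
    dimensions: int | None = None

params = EmbeddingsRequest(model="all-MiniLM-L6-v2", input="hello world")
# Unset optional fields are dropped, so the downstream server only sees
# what the caller actually provided.
print(params.model_dump(exclude_none=True))
# {'model': 'all-MiniLM-L6-v2', 'input': 'hello world'}
```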