mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
[bugfix] fix broken vision inference, change serialization for bytes (#693)
# What does this PR do? - vision inference via image as binary bytes fails with serialization error - add custom serialization for "bytes" in `_URLOrData` ## Test Plan ``` pytest -v -s -k "fireworks" --inference-model="meta-llama/Llama-3.2-11B-Vision-Instruct" ./llama_stack/providers/tests/inference/test_vision_inference.py::TestVisionModelInference::test_vision_chat_completion_non_streaming ``` **Before** <img width="1020" alt="image" src="https://github.com/user-attachments/assets/3803fcee-32ee-4b8e-ba46-47848e1a6247" /> **After** <img width="1018" alt="image" src="https://github.com/user-attachments/assets/f3e3156e-88ce-40fd-ad1b-44b87f376e03" /> <img width="822" alt="image" src="https://github.com/user-attachments/assets/1898696f-95c0-4694-8a47-8f51c7de0e86" /> ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
This commit is contained in:
parent
79f8bc8416
commit
694adb1501
1 changed files with 8 additions and 1 deletions
|
@ -4,11 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, register_schema
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field, field_serializer, model_validator
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -27,6 +28,12 @@ class _URLOrData(BaseModel):
|
|||
return values
|
||||
return {"url": values}
|
||||
|
||||
@field_serializer("data")
|
||||
def serialize_data(self, data: Optional[bytes], _info):
|
||||
if data is None:
|
||||
return None
|
||||
return base64.b64encode(data).decode("utf-8")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ImageContentItem(_URLOrData):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue