forked from phoenix-oss/llama-stack-mirror
Compare commits
1 commit
kvant
...
inject-met
Author | SHA1 | Date | |
---|---|---|---|
|
afb81da91a |
3 changed files with 1856 additions and 1011 deletions
1749
docs/_static/llama-stack-spec.html
vendored
1749
docs/_static/llama-stack-spec.html
vendored
File diff suppressed because it is too large
Load diff
1093
docs/_static/llama-stack-spec.yaml
vendored
1093
docs/_static/llama-stack-spec.yaml
vendored
File diff suppressed because it is too large
Load diff
|
@ -10,9 +10,10 @@ import inspect
|
||||||
import typing
|
import typing
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
|
||||||
|
from pydantic import BaseModel, create_model
|
||||||
|
|
||||||
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
||||||
|
from llama_stack.apis.telemetry.telemetry import MetricEvent
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
from ..strong_typing.inspection import get_signature
|
from ..strong_typing.inspection import get_signature
|
||||||
|
@ -303,8 +304,28 @@ def get_endpoint_operations(
|
||||||
return typing._UnionGenericAlias(typing.Union, tuple(types))
|
return typing._UnionGenericAlias(typing.Union, tuple(types))
|
||||||
else:
|
else:
|
||||||
return t
|
return t
|
||||||
|
def augment_response_with_metrics(t):
|
||||||
|
if t in (int, float, str, list):
|
||||||
|
return t
|
||||||
|
elif typing.get_origin(t) is typing.Union:
|
||||||
|
types = [augment_response_with_metrics(a) for a in typing.get_args(t)]
|
||||||
|
return typing._UnionGenericAlias(typing.Union, tuple(types))
|
||||||
|
elif isinstance(t, type) and issubclass(t, BaseModel):
|
||||||
|
if "metric_events" in t.model_fields:
|
||||||
|
print(f"warning: {t.__name__} already has metric_events field")
|
||||||
|
return t
|
||||||
|
if "data" in t.model_fields:
|
||||||
|
print(f"warning: {t.__name__} has a data field, metrics are not added")
|
||||||
|
return t
|
||||||
|
return create_model(
|
||||||
|
t.__name__,
|
||||||
|
__base__=t,
|
||||||
|
metrics=(Optional[List[MetricEvent]], None)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return t
|
||||||
|
|
||||||
response_type = process_type(return_type)
|
response_type = augment_response_with_metrics(process_type(return_type))
|
||||||
|
|
||||||
if prefix in ["delete", "remove"]:
|
if prefix in ["delete", "remove"]:
|
||||||
http_method = HTTPMethod.DELETE
|
http_method = HTTPMethod.DELETE
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue