litellm/litellm/llms/OpenAI/chat/o1_handler.py

"""
Handler file for calls to OpenAI's o1 family of models
Written separately to handle faking streaming for o1 models.
"""
import asyncio
from typing import Any, Callable, List, Optional, Union
from httpx._config import Timeout
from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
from litellm.llms.OpenAI.openai import OpenAIChatCompletion
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper
class OpenAIO1ChatCompletion(OpenAIChatCompletion):
async def mock_async_streaming(
self,
response: Any,
model: Optional[str],
logging_obj: Any,
):
model_response = await response
completion_stream = MockResponseIterator(model_response=model_response)
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="openai",
logging_obj=logging_obj,
)
return streaming_response
def completion(
self,
model_response: ModelResponse,
timeout: Union[float, Timeout],
optional_params: dict,
logging_obj: Any,
model: Optional[str] = None,
messages: Optional[list] = None,
print_verbose: Optional[Callable[..., Any]] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
acompletion: bool = False,
litellm_params=None,
logger_fn=None,
headers: Optional[dict] = None,
custom_prompt_dict: dict = {},
client=None,
organization: Optional[str] = None,
custom_llm_provider: Optional[str] = None,
drop_params: Optional[bool] = None,
):
stream: Optional[bool] = optional_params.pop("stream", False)
response = super().completion(
model_response,
timeout,
optional_params,
logging_obj,
model,
messages,
print_verbose,
api_key,
api_base,
acompletion,
litellm_params,
logger_fn,
headers,
custom_prompt_dict,
client,
organization,
custom_llm_provider,
drop_params,
)
if stream is True:
if asyncio.iscoroutine(response):
return self.mock_async_streaming(
response=response, model=model, logging_obj=logging_obj # type: ignore
)
completion_stream = MockResponseIterator(model_response=response)
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="openai",
logging_obj=logging_obj,
)
return streaming_response
else:
return response
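

For reference, a minimal usage sketch of the fake-streaming path follows. It assumes `OPENAI_API_KEY` is set in the environment and that LiteLLM routes o1 chat requests through this handler; it is an illustrative sketch, not part of the module above.

import litellm

# Requesting stream=True for an o1 model: the handler strips `stream`,
# performs a normal (non-streaming) request, and then replays the full
# response as stream chunks via CustomStreamWrapper.
response = litellm.completion(
    model="o1-preview",
    messages=[{"role": "user", "content": "Say hello in one word."}],
    stream=True,
)

for chunk in response:
    # Each item is a streamed chunk produced from the single full response.
    print(chunk)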