commit 56bd8c1c52 (parent 6cb03d7c63)

ollama upgrades: fix streaming, add non-streaming response

5 changed files with 135 additions and 86 deletions
@@ -28,6 +28,7 @@ from .llms import replicate
 from .llms import aleph_alpha
 from .llms import baseten
 from .llms import vllm
+from .llms import ollama
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
 from typing import Callable, List, Optional, Dict
@@ -39,9 +40,6 @@ from litellm.utils import (
     ModelResponse,
     read_config_args,
 )
-from litellm.utils import (
-    get_ollama_response_stream,
-)

 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv()  # Loading env variables using dotenv
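Taken together, the two import hunks above move the Ollama streaming helper out of litellm.utils and into the litellm.llms.ollama provider module. A minimal sketch of calling the relocated helper directly, where the endpoint URL, model name, and prompt are illustrative assumptions rather than values from this commit:

# Sketch: calling the relocated helper directly. The endpoint and model
# values below are assumptions for illustration, not taken from this commit.
from litellm.llms import ollama  # new location (previously litellm.utils)

endpoint = "http://localhost:11434"  # assumed local Ollama server
model = "llama2"                     # assumed locally pulled model
prompt = "Why is the sky blue?"

# The helper yields OpenAI-style streaming chunks, as the consuming code
# in completion() below indicates.
for chunk in ollama.get_ollama_response_stream(endpoint, model, prompt):
    print(chunk["choices"][0]["delta"]["content"], end="")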
@@ -728,10 +726,27 @@ def completion(
         logging.pre_call(
             input=prompt, api_key=None, additional_args={"endpoint": endpoint}
         )

-        generator = get_ollama_response_stream(endpoint, model, prompt)
-        # assume all responses are streamed
-        return generator
+        generator = ollama.get_ollama_response_stream(endpoint, model, prompt)
+        if optional_params.get("stream", False) == True:
+            # assume all ollama responses are streamed
+            return generator
+        else:
+            response_string = ""
+            for chunk in generator:
+                response_string+=chunk['choices'][0]['delta']['content']
+            ## RESPONSE OBJECT
+            model_response["choices"][0]["message"]["content"] = response_string
+            model_response["created"] = time.time()
+            model_response["model"] = "ollama/" + model
+            prompt_tokens = len(encoding.encode(prompt))
+            completion_tokens = len(encoding.encode(response_string))
+            model_response["usage"] = {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": prompt_tokens + completion_tokens,
+            }
+            response = model_response
     elif (
         custom_llm_provider == "baseten"
         or litellm.api_base == "https://app.baseten.co"
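For context, a minimal usage sketch of the stream/non-stream split introduced by this hunk. The api_base URL, model name, and keyword arguments below are assumptions for illustration and are not taken from this diff; only the chunk and response field accesses mirror the committed code:

# With stream=True, completion() returns the raw chunk generator unchanged.
import litellm

messages = [{"role": "user", "content": "Why is the sky blue?"}]

for chunk in litellm.completion(
    model="llama2",                   # assumed locally pulled model
    messages=messages,
    custom_llm_provider="ollama",
    api_base="http://localhost:11434",  # assumed local Ollama server
    stream=True,
):
    print(chunk["choices"][0]["delta"]["content"], end="")

# Without stream=True, the generator is drained internally, the deltas are
# concatenated, and a full ModelResponse comes back, with prompt and
# completion token counts estimated via tiktoken.
response = litellm.completion(
    model="llama2",
    messages=messages,
    custom_llm_provider="ollama",
    api_base="http://localhost:11434",
)
print(response["choices"][0]["message"]["content"])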