Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 11:14:04 +00:00
(feat) ollama use /api/chat
This commit is contained in:
parent da4ec6c8b6
commit d85c19394f

1 changed file with 2 additions and 74 deletions
@@ -132,10 +132,10 @@ def get_ollama_response(
     model_response=None,
     encoding=None,
 ):
-    if api_base.endswith("/api/generate"):
+    if api_base.endswith("/api/chat"):
         url = api_base
     else:
-        url = f"{api_base}/api/generate"
+        url = f"{api_base}/api/chat"
 
     ## Load Config
     config = litellm.OllamaConfig.get_config()
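
Note: Ollama's /api/generate and /api/chat endpoints accept different request bodies, so the URL change above goes hand in hand with sending a chat-style payload. A minimal sketch of the two shapes (illustrative only; the model name and prompt are placeholders, not taken from this commit):

# Sketch: request bodies for Ollama's two text-generation endpoints.
# /api/generate takes a single prompt string and replies with {"response": "..."}.
generate_payload = {
    "model": "llama2",
    "prompt": "Why is the sky blue?",
}

# /api/chat takes a list of role/content messages and replies with
# {"message": {"role": "assistant", "content": "..."}}.
chat_payload = {
    "model": "llama2",
    "messages": [{"role": "user", "content": "Why is the sky blue?"}],
}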
@@ -329,75 +329,3 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
     except Exception as e:
         traceback.print_exc()
         raise e
-
-
-async def ollama_aembeddings(
-    api_base="http://localhost:11434",
-    model="llama2",
-    prompt="Why is the sky blue?",
-    optional_params=None,
-    logging_obj=None,
-    model_response=None,
-    encoding=None,
-):
-    if api_base.endswith("/api/embeddings"):
-        url = api_base
-    else:
-        url = f"{api_base}/api/embeddings"
-
-    ## Load Config
-    config = litellm.OllamaConfig.get_config()
-    for k, v in config.items():
-        if (
-            k not in optional_params
-        ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
-            optional_params[k] = v
-
-    data = {
-        "model": model,
-        "prompt": prompt,
-    }
-    ## LOGGING
-    logging_obj.pre_call(
-        input=None,
-        api_key=None,
-        additional_args={"api_base": url, "complete_input_dict": data, "headers": {}},
-    )
-    timeout = aiohttp.ClientTimeout(total=litellm.request_timeout)  # 10 minutes
-    async with aiohttp.ClientSession(timeout=timeout) as session:
-        response = await session.post(url, json=data)
-
-        if response.status != 200:
-            text = await response.text()
-            raise OllamaError(status_code=response.status, message=text)
-
-        ## LOGGING
-        logging_obj.post_call(
-            input=prompt,
-            api_key="",
-            original_response=response.text,
-            additional_args={
-                "headers": None,
-                "api_base": api_base,
-            },
-        )
-
-        response_json = await response.json()
-        embeddings = response_json["embedding"]
-        ## RESPONSE OBJECT
-        output_data = []
-        for idx, embedding in enumerate(embeddings):
-            output_data.append(
-                {"object": "embedding", "index": idx, "embedding": embedding}
-            )
-        model_response["object"] = "list"
-        model_response["data"] = output_data
-        model_response["model"] = model
-
-        input_tokens = len(encoding.encode(prompt))
-
-        model_response["usage"] = {
-            "prompt_tokens": input_tokens,
-            "total_tokens": input_tokens,
-        }
-        return model_response
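
Because the two endpoints also reply in different shapes, a caller that previously read the "response" field of each returned chunk has to read "message"/"content" once it targets /api/chat. A hedged sketch of that difference; the helper name below is hypothetical, not part of this commit:

import json

# Hypothetical helper: pull the generated text out of one JSON chunk of an
# Ollama response, depending on which endpoint produced it.
def extract_text(chunk_line: str, endpoint: str) -> str:
    chunk = json.loads(chunk_line)
    if endpoint.endswith("/api/chat"):
        # /api/chat wraps the text in a chat message object.
        return chunk.get("message", {}).get("content", "")
    # /api/generate returns the text directly under "response".
    return chunk.get("response", "")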