diff --git a/.env b/.env index f2ce2f9..8343c5a 100644 --- a/.env +++ b/.env @@ -4,9 +4,10 @@ LOG_CHAT=true # `CACHE_BACKEND`: Options (MEMORY, LMDB, LevelDB) -CACHE_BACKEND=LMDB +CACHE_BACKEND=leveldb CACHE_ROOT_PATH_OR_URL=./FLAXKV_DB CACHE_CHAT_COMPLETION=true +CACHE_EMBEDDING=true DEFAULT_REQUEST_CACHING_VALUE=false #LOG_CACHE_DB_INFO=true @@ -25,6 +26,7 @@ OPENAI_ROUTE_PREFIX=/ CHAT_COMPLETION_ROUTE=/v1/chat/completions COMPLETION_ROUTE=/v1/completions +EMBEDDING_ROUTE=/v1/embeddings # `EXTRA_BASE_URL`: Specify any service for forwarding #EXTRA_BASE_URL= diff --git a/Examples/chat_completion.py b/Examples/chat_completion.py index dd8384e..002e247 100644 --- a/Examples/chat_completion.py +++ b/Examples/chat_completion.py @@ -9,74 +9,36 @@ api_key=config['api_key'], base_url=config['api_base'], ) + stream = True -# stream = False n = 1 -# debug = True -debug = False - -is_tool_calls = True -# is_tool_call = False -caching = True +debug, caching = False, True max_tokens = None user_content = """ 用c实现目前已知最快平方根算法 """ -user_content = 'hi' +# user_content = '最初有1000千克的蘑菇,其中99%的成分是水。经过几天的晴天晾晒后,蘑菇中的水分含量现在是98%,蘑菇中减少了多少水分?' +user_content = "讲个简短的笑话" model = "gpt-3.5-turbo" # model="gpt-4" mt = MeasureTime().start() -if is_tool_calls: - tools = [ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", - }, - "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, - }, - "required": ["location"], - }, - }, - } - ] - resp = client.chat.completions.create( - model=model, - messages=[ - {"role": "user", "content": "What's the weather like in Boston today?"} - ], - tools=tools, - tool_choice="auto", # auto is default, but we'll be explicit - stream=stream, - extra_body={"caching": caching}, - ) - -else: - resp = client.chat.completions.create( - model=model, - messages=[ - {"role": "user", "content": user_content}, - ], - stream=stream, - n=n, - max_tokens=max_tokens, - timeout=30, - # extra_headers=(caching, caching) - extra_body={"caching": caching}, - ) +resp = client.chat.completions.create( + model=model, + messages=[ + {"role": "user", "content": user_content}, + ], + stream=stream, + n=n, + max_tokens=max_tokens, + timeout=30, + extra_body={"caching": caching}, +) if stream: if debug: @@ -86,23 +48,10 @@ for idx, chunk in enumerate(resp): chunk_message = chunk.choices[0].delta or "" if idx == 0: - if is_tool_calls: - function = chunk_message.tool_calls[0].function - name = function.name - print(f"{chunk_message.role}: \n{name}: ") - else: - print(f"{chunk_message.role}: ") + mt.show_interval("tcp time:") + print(f"{chunk_message.role}: ") continue - - content = "" - if is_tool_calls: - tool_calls = chunk_message.tool_calls - if tool_calls: - function = tool_calls[0].function - if function: - content = function.arguments or "" - else: - content = chunk_message.content or "" + content = chunk_message.content or "" print(content, end="") print() else: diff --git a/Examples/embedding.py b/Examples/embedding.py index a88bcd7..4e22e3d 100644 --- a/Examples/embedding.py +++ b/Examples/embedding.py @@ -1,11 +1,18 @@ -import openai +from openai import OpenAI from sparrow import yaml_load config = yaml_load("config.yaml") -openai.api_base = config["api_base"] -openai.api_key = config["api_key"] -response = openai.Embedding.create( - input="Your text string goes here", model="text-embedding-ada-002" +client = OpenAI( + api_key=config['api_key'], + base_url=config['api_base'], ) -embeddings = response['data'][0]['embedding'] -print(embeddings) + +response = client.embeddings.create( + model="text-embedding-ada-002", + input="你好", + encoding_format="float", + # encoding_format="base64", + extra_body={"caching": True}, +) + +print(response) diff --git a/openai_forward/__init__.py b/openai_forward/__init__.py index 02465c1..87fa7a8 100644 --- a/openai_forward/__init__.py +++ b/openai_forward/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.6.10" +__version__ = "0.6.11" from dotenv import load_dotenv diff --git a/openai_forward/app.py b/openai_forward/app.py index a30dee6..778e9e6 100644 --- a/openai_forward/app.py +++ b/openai_forward/app.py @@ -54,7 +54,7 @@ def healthz(request: Request): if BENCHMARK_MODE: - from .cache.chat_completions import chat_completions_benchmark + from .cache.chat.chat_completions import chat_completions_benchmark app.add_route( "/benchmark/v1/chat/completions", diff --git a/openai_forward/cache/__init__.py b/openai_forward/cache/__init__.py index e69de29..a2377b9 100644 --- a/openai_forward/cache/__init__.py +++ b/openai_forward/cache/__init__.py @@ -0,0 +1,45 @@ +from ..settings import ( + CACHE_CHAT_COMPLETION, + CACHE_EMBEDDING, + CHAT_COMPLETION_ROUTE, + EMBEDDING_ROUTE, +) +from .chat.response import get_cached_chat_response +from .database import db_dict +from .embedding.response import get_cached_embedding_response + + +def get_cached_response(payload_info, valid_payload, route_path, request): + if route_path == EMBEDDING_ROUTE: + return get_cached_embedding_response(payload_info, valid_payload, request) + elif route_path == CHAT_COMPLETION_ROUTE: + return get_cached_chat_response(payload_info, valid_payload, request) + else: + return None, None + + +def cache_response(cache_key, target_info, route_path): + if ( + target_info + and CACHE_CHAT_COMPLETION + and route_path == CHAT_COMPLETION_ROUTE + and cache_key is not None + ): + cached_value = db_dict.get(cache_key, {"data": []})["data"] + if len(cached_value) < 10: + cached_value.append(target_info["assistant"]) + db_dict[cache_key] = { + "data": cached_value, + "route_path": route_path, + } + elif ( + target_info + and CACHE_EMBEDDING + and route_path == EMBEDDING_ROUTE + and cache_key is not None + ): + cached_value = bytes(target_info["buffer"]) + db_dict[cache_key] = { + "data": cached_value, + "route_path": route_path, + } diff --git a/openai_forward/cache/chat/__init__.py b/openai_forward/cache/chat/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/openai_forward/cache/chat_completions.py b/openai_forward/cache/chat/chat_completions.py similarity index 95% rename from openai_forward/cache/chat_completions.py rename to openai_forward/cache/chat/chat_completions.py index 5344fc8..3827e3c 100644 --- a/openai_forward/cache/chat_completions.py +++ b/openai_forward/cache/chat/chat_completions.py @@ -9,9 +9,9 @@ from fastapi import Request from fastapi.responses import Response, StreamingResponse -from ..decorators import async_random_sleep, async_token_rate_limit -from ..helper import get_unique_id -from ..settings import token_interval_conf +from ...decorators import async_random_sleep, async_token_rate_limit +from ...helper import get_unique_id +from ...settings import token_interval_conf from .tokenizer import TIKTOKEN_VALID, count_tokens, encode_as_pieces diff --git a/openai_forward/cache/chat/response.py b/openai_forward/cache/chat/response.py new file mode 100644 index 0000000..cb2006f --- /dev/null +++ b/openai_forward/cache/chat/response.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import random + +from fastapi.responses import Response, StreamingResponse +from flaxkv.pack import encode +from loguru import logger + +from ...settings import CACHE_CHAT_COMPLETION +from ..database import db_dict +from .chat_completions import generate, stream_generate_efficient + + +def construct_cache_key(payload_info: dict): + elements = [ + payload_info["n"], + payload_info['messages'], + payload_info['model'], + payload_info["max_tokens"], + payload_info['response_format'], + payload_info['seed'], + # payload_info['temperature'], + payload_info["tools"], + payload_info["tool_choice"], + ] + + return encode(elements) + + +def get_response_from_key(key, payload_info, request): + value = db_dict[key] + cache_values = value.get('data') + + if cache_values is None: # deprecate soon + # compatible with previous version + cache_values = value + idx = random.randint(0, len(cache_values) - 1) if len(cache_values) > 1 else 0 + logger.info(f'chat uid: {payload_info["uid"]} >>>{idx}>>>> [cache hit]') + # todo: handle multiple choices + cache_value = cache_values[idx] + if isinstance(cache_value, list): + text = None + tool_calls = cache_value + else: + text = cache_value + tool_calls = None + + if payload_info["stream"]: + return StreamingResponse( + stream_generate_efficient( + payload_info['model'], + text, + tool_calls, + request, + ), + status_code=200, + media_type="text/event-stream", + ) + + else: + usage = { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + } + return Response( + content=generate(payload_info['model'], text, tool_calls, usage), + media_type="application/json", + ) + + +def get_cached_chat_response(payload_info, valid_payload, request): + """ + Attempts to retrieve a cached response based on the current request's payload information. + + This function constructs a cache key based on various aspects of the request payload, + checks if the response for this key has been cached, and if so, constructs and returns + the appropriate cached response. + + Returns: + Tuple[Union[Response, None], Union[str, None]]: + - Response (Union[Response, None]): The cached response if available; otherwise, None. + - cache_key (Union[str, None]): The constructed cache key for the request. None if caching is not applicable. + + Note: + If a cache hit occurs, the cached response is immediately returned without contacting the external server. + """ + if not (CACHE_CHAT_COMPLETION and valid_payload): + return None, None + + cache_key = construct_cache_key(payload_info) + + if payload_info['caching'] and cache_key in db_dict: + return get_response_from_key(cache_key, payload_info, request), cache_key + + return None, cache_key diff --git a/openai_forward/cache/tokenizer.py b/openai_forward/cache/chat/tokenizer.py similarity index 100% rename from openai_forward/cache/tokenizer.py rename to openai_forward/cache/chat/tokenizer.py diff --git a/openai_forward/cache/database.py b/openai_forward/cache/database.py index 42e96f7..346c94b 100644 --- a/openai_forward/cache/database.py +++ b/openai_forward/cache/database.py @@ -7,22 +7,16 @@ elif CACHE_BACKEND.lower() in ("leveldb", "lmdb"): - if LOG_CACHE_DB_INFO: - db_dict = FlaxKV( - "CACHE_DB", - root_path_or_url=CACHE_ROOT_PATH_OR_URL, - backend=CACHE_BACKEND.lower(), - cache=True, - log="INFO", - save_log=True, - ) - else: - db_dict = FlaxKV( - "CACHE_DB", - root_path_or_url=CACHE_ROOT_PATH_OR_URL, - backend=CACHE_BACKEND.lower(), - cache=True, - ) + log, save_log = ("INFO", True) if LOG_CACHE_DB_INFO else (None, False) + + db_dict = FlaxKV( + "CACHE_DB", + root_path_or_url=CACHE_ROOT_PATH_OR_URL, + backend=CACHE_BACKEND.lower(), + cache=True, + log=log, + save_log=save_log, + ) else: raise ValueError( diff --git a/openai_forward/cache/embedding/__init__.py b/openai_forward/cache/embedding/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/openai_forward/cache/embedding/response.py b/openai_forward/cache/embedding/response.py new file mode 100644 index 0000000..079c509 --- /dev/null +++ b/openai_forward/cache/embedding/response.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from fastapi.responses import Response +from flaxkv.pack import encode +from loguru import logger + +from ...settings import CACHE_EMBEDDING +from ..database import db_dict + + +def construct_cache_key(payload_info): + elements = [ + payload_info['model'], + payload_info['input'], + payload_info['encoding_format'], + ] + return encode(elements) + + +def get_cached_embedding_response(payload_info, valid_payload, request): + """ + Attempts to retrieve a cached response based on the current request's payload information. + + Note: + If a cache hit occurs, the cached response is immediately returned without contacting the external server. + """ + if not (CACHE_EMBEDDING and valid_payload): + return None, None + + cache_key = construct_cache_key(payload_info) + + if payload_info['caching'] and cache_key in db_dict: + logger.info(f'embedding uid: {payload_info["uid"]} >>>>> [cache hit]') + cache_data = db_dict[cache_key] + content = cache_data['data'] + + return Response(content=content, media_type="application/json"), cache_key + + return None, cache_key diff --git a/openai_forward/content/config.py b/openai_forward/content/config.py index cd78a81..1b9760e 100644 --- a/openai_forward/content/config.py +++ b/openai_forward/content/config.py @@ -85,6 +85,13 @@ def get_utc_offset(timezone_str): "filter": lambda record: f"{_prefix}_completion" in record["extra"], "format": "{message}", }, + { + "sink": f"./Log/{_prefix}/embedding/embedding.log", + "enqueue": multi_process, + "rotation": "100 MB", + "filter": lambda record: f"{_prefix}_embedding" in record["extra"], + "format": "{message}", + }, { "sink": f"./Log/{_prefix}/whisper/whisper.log", "enqueue": multi_process, diff --git a/openai_forward/content/openai.py b/openai_forward/content/openai.py index aadf8b6..2ea629d 100644 --- a/openai_forward/content/openai.py +++ b/openai_forward/content/openai.py @@ -1,5 +1,6 @@ import time -from typing import List +from abc import ABC, abstractmethod +from typing import Dict, List, Tuple import orjson from fastapi import Request @@ -11,7 +12,27 @@ from .helper import markdown_print, parse_sse_buffer, print -class CompletionLogger: +class LoggerBase(ABC): + def __init__(self, route_prefix: str, _suffix: str): + _prefix = route_prefix_to_str(route_prefix) + kwargs = {f"{_prefix}{_suffix}": True} + self.logger = logger.bind(**kwargs) + + @staticmethod + @abstractmethod + def parse_payload(request: Request, raw_payload) -> Tuple[Dict, bytes]: + pass + + @staticmethod + @abstractmethod + def parse_bytearray(buffer: bytearray) -> Dict: + pass + + def log_result(self, *args, **kwargs): + pass + + +class CompletionLogger(LoggerBase): def __init__(self, route_prefix: str): """ Initialize the Completions logger with a route prefix. @@ -19,14 +40,12 @@ def __init__(self, route_prefix: str): Args: route_prefix (str): The prefix used for routing, e.g., '/openai'. """ - _prefix = route_prefix_to_str(route_prefix) - kwargs = {_prefix + "_completion": True} - self.logger = logger.bind(**kwargs) + super().__init__(route_prefix, "_completion") @staticmethod - async def parse_payload(request: Request): + def parse_payload(request: Request, raw_payload): uid = get_unique_id() - payload = await request.json() + payload = orjson.loads(raw_payload) print(f"{payload=}") content = { @@ -38,7 +57,7 @@ async def parse_payload(request: Request): "uid": uid, "datetime": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), } - return content, payload + return content, raw_payload @staticmethod def parse_bytearray(buffer: bytearray): @@ -88,7 +107,7 @@ def parse_bytearray(buffer: bytearray): target_info['text'] = text return target_info - def log(self, chat_info: dict): + def log_result(self, chat_info: dict): """ Log chat information to the logger bound to this instance. @@ -98,7 +117,7 @@ def log(self, chat_info: dict): self.logger.debug(f"{chat_info}") -class ChatLogger: +class ChatLogger(LoggerBase): def __init__(self, route_prefix: str): """ Initialize the Chat Completions logger with a route prefix. @@ -106,12 +125,10 @@ def __init__(self, route_prefix: str): Args: route_prefix (str): The prefix used for routing, e.g., '/openai'. """ - _prefix = route_prefix_to_str(route_prefix) - kwargs = {_prefix + "_chat": True} - self.logger = logger.bind(**kwargs) + super().__init__(route_prefix, "_chat") @staticmethod - async def parse_payload(request: Request): + def parse_payload(request: Request, raw_payload): """ Asynchronously parse the payload from a FastAPI request. @@ -122,42 +139,38 @@ async def parse_payload(request: Request): dict: A dictionary containing parsed messages, model, IP address, UID, and datetime. """ uid = get_unique_id() - payload = await request.json() - - # functions = payload.get("functions") # deprecated - # if functions: - # info = { - # "functions": functions, # Deprecated in favor of `tools` - # "function_call": payload.get("function_call", None), # Deprecated in favor of `tool_choice` - # } - info = {} - info.update( - { - "messages": payload["messages"], - "model": payload["model"], - "stream": payload.get("stream", False), - "max_tokens": payload.get("max_tokens", None), - "response_format": payload.get("response_format", None), - "n": payload.get("n", 1), - "temperature": payload.get("temperature", 1), - "top_p": payload.get("top_p", 1), - "logit_bias": payload.get("logit_bias", None), - "frequency_penalty": payload.get("frequency_penalty", 0), - "presence_penalty": payload.get("presence_penalty", 0), - "seed": payload.get("seed", None), - "stop": payload.get("stop", None), - "user": payload.get("user", None), - "tools": payload.get("tools", None), - "tool_choice": payload.get("tool_choice", None), - "ip": get_client_ip(request) or "", - "uid": uid, - "caching": payload.pop( - "caching", DEFAULT_REQUEST_CACHING_VALUE - ), # pop caching - "datetime": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), - } - ) - return info, orjson.dumps(payload) + payload = orjson.loads(raw_payload) + caching = payload.pop("caching", None) + if caching is None: + caching = DEFAULT_REQUEST_CACHING_VALUE + payload_return = raw_payload + else: + payload_return = orjson.dumps(payload) + + info = { + "messages": payload["messages"], + "model": payload["model"], + "stream": payload.get("stream", False), + "max_tokens": payload.get("max_tokens", None), + "response_format": payload.get("response_format", None), + "n": payload.get("n", 1), + "temperature": payload.get("temperature", 1), + "top_p": payload.get("top_p", 1), + "logit_bias": payload.get("logit_bias", None), + "frequency_penalty": payload.get("frequency_penalty", 0), + "presence_penalty": payload.get("presence_penalty", 0), + "seed": payload.get("seed", None), + "stop": payload.get("stop", None), + "user": payload.get("user", None), + "tools": payload.get("tools", None), + "tool_choice": payload.get("tool_choice", None), + "ip": get_client_ip(request) or "", + "uid": uid, + "caching": caching, + "datetime": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), + } + + return info, payload_return def parse_bytearray(self, buffer: bytearray): """ @@ -267,6 +280,15 @@ def log(self, chat_info: dict): """ self.logger.debug(f"{chat_info}") + def log_result(self, chat_info: dict): + """ + Log chat information to the logger bound to this instance. + + Args: + chat_info (dict): A dictionary containing chat information to be logged. + """ + self.logger.debug(f"{chat_info}") + @staticmethod def print_chat_info(chat_info: dict): """ @@ -295,6 +317,60 @@ def print_chat_info(chat_info: dict): print(77 * "=", role='assistant') +class EmbeddingLogger: + def __init__(self, route_prefix: str): + _prefix = route_prefix_to_str(route_prefix) + kwargs = {_prefix + "_embedding": True} + self.logger = logger.bind(**kwargs) + + @staticmethod + def parse_payload(request: Request, raw_payload: bytes): + uid = get_unique_id() + payload = orjson.loads(raw_payload) + caching = payload.pop("caching", None) + if caching is None: + caching = DEFAULT_REQUEST_CACHING_VALUE + payload_return = raw_payload + else: + payload_return = orjson.dumps(payload) + + content = { + "input": payload['input'], + "model": payload['model'], + "encoding_format": payload.get("encoding_format", 'float'), + "ip": get_client_ip(request) or "", + "uid": uid, + "caching": caching, + "datetime": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), + } + return content, payload_return + + def parse_bytearray(self, buffer: bytearray): + """ + Parse a bytearray into a dictionary. + """ + result_dict = orjson.loads(buffer) + target_info = { + "object": result_dict["object"], + "usage": result_dict["usage"], + "model": result_dict["model"], + "buffer": buffer, + } + return target_info + + def log(self, info: dict): + self.logger.debug(f"{info}") + + def log_result(self, info: dict): + result_info = { + "object": info["object"], + "usage": info["usage"], + "model": info["model"], + "uid": info["uid"], + } + self.logger.debug(f"{result_info}") + + class WhisperLogger: def __init__(self, route_prefix: str): """ diff --git a/openai_forward/forward/base.py b/openai_forward/forward/base.py index 756c099..5967156 100644 --- a/openai_forward/forward/base.py +++ b/openai_forward/forward/base.py @@ -10,13 +10,16 @@ import anyio from aiohttp import TCPConnector from fastapi import HTTPException, Request, status -from flaxkv.pack import encode from loguru import logger -from starlette.responses import BackgroundTask, Response, StreamingResponse - -from ..cache.chat_completions import generate, stream_generate_efficient -from ..cache.database import db_dict -from ..content.openai import ChatLogger, CompletionLogger, WhisperLogger +from starlette.responses import BackgroundTask, StreamingResponse + +from ..cache import cache_response, get_cached_response +from ..content.openai import ( + ChatLogger, + CompletionLogger, + EmbeddingLogger, + WhisperLogger, +) from ..decorators import async_retry, async_token_rate_limit from ..settings import * @@ -199,6 +202,7 @@ def __init__(self, base_url: str, route_prefix: str, proxy=None): self.chat_logger = ChatLogger(self.ROUTE_PREFIX) self.completion_logger = CompletionLogger(self.ROUTE_PREFIX) self.whisper_logger = WhisperLogger(self.ROUTE_PREFIX) + self.embedding_logger = EmbeddingLogger(self.ROUTE_PREFIX) def _handle_result( self, buffer: bytearray, uid: str, route_path: str, request_method: str @@ -228,6 +232,8 @@ def _handle_result( logger_instance = self.chat_logger elif route_path == COMPLETION_ROUTE: logger_instance = self.completion_logger + elif route_path == EMBEDDING_ROUTE: + logger_instance = self.embedding_logger elif route_path.startswith("/v1/audio/"): self.whisper_logger.log_buffer(buffer) return result_info @@ -238,7 +244,7 @@ def _handle_result( result_info["uid"] = uid if LOG_CHAT: - logger_instance.log(result_info) + logger_instance.log_result(result_info) if PRINT_CHAT and logger_instance == self.chat_logger: self.chat_logger.print_chat_info(result_info) @@ -265,8 +271,9 @@ async def _handle_payload(self, request: Request, url_path: str): """ payload_log_info = {"uid": None} + payload = await request.body() + if not (LOG_CHAT or PRINT_CHAT) or request.method != "POST": - payload = await request.body() return False, payload_log_info, payload try: @@ -276,13 +283,17 @@ async def _handle_payload(self, request: Request, url_path: str): logger_instance = self.chat_logger elif url_path == COMPLETION_ROUTE: logger_instance = self.completion_logger + elif url_path == EMBEDDING_ROUTE: + logger_instance = self.embedding_logger # If a logger method is determined, parse payload and log if necessary if logger_instance: - payload_log_info, payload = await logger_instance.parse_payload(request) + payload_log_info, payload = logger_instance.parse_payload( + request, payload + ) if payload_log_info and LOG_CHAT: - logger_instance.log(payload_log_info) + logger_instance.logger.debug(payload_log_info) if ( payload_log_info @@ -291,13 +302,12 @@ async def _handle_payload(self, request: Request, url_path: str): ): self.chat_logger.print_chat_info(payload_log_info) else: - payload = await request.body() + ... except Exception as e: logger.warning( f"log chat error:\nhost:{request.client.host} method:{request.method}: {traceback.format_exc()}" ) - payload = await request.body() valid = True if payload_log_info['uid'] is not None else False return valid, payload_log_info, payload @@ -372,10 +382,8 @@ async def aiter_bytes( target_info = self._handle_result( chunk, uid, route_path, request.method ) - if target_info and CACHE_CHAT_COMPLETION and cache_key is not None: - cached_value = db_dict.get(cache_key, []) - cached_value.append(target_info["assistant"]) - db_dict[cache_key] = cached_value + cache_response(cache_key, target_info, route_path) + elif chunk is not None: logger.warning( f'uid: {uid}\n status: {r.status}\n {chunk.decode("utf-8")}' @@ -390,85 +398,6 @@ def handle_authorization(self, client_config): client_config["headers"]["Authorization"] = auth return auth - @staticmethod - def _get_cached_response(payload_info, valid_payload, request): - """ - Attempts to retrieve a cached response based on the current request's payload information. - - This function constructs a cache key based on various aspects of the request payload, - checks if the response for this key has been cached, and if so, constructs and returns - the appropriate cached response. - - Returns: - Tuple[Union[Response, None], Union[str, None]]: - - Response (Union[Response, None]): The cached response if available; otherwise, None. - - cache_key (Union[str, None]): The constructed cache key for the request. None if caching is not applicable. - - Note: - If a cache hit occurs, the cached response is immediately returned without contacting the external server. - """ - # todo: refactor this function - - def construct_cache_key(): - elements = [ - payload_info["n"], - payload_info['messages'], - payload_info['model'], - payload_info["max_tokens"], - payload_info['response_format'], - payload_info['seed'], - # payload_info['temperature'], - payload_info["tools"], - payload_info["tool_choice"], - ] - - return encode(elements) - - def get_response_from_cache(key): - logger.info(f'uid: {payload_info["uid"]} >>>>> [cache hit]') - cache_values = db_dict[key] - # todo: handle multiple choices - cache_value = cache_values[-1] - if isinstance(cache_value, list): - text = None - tool_calls = cache_value - else: - text = cache_value - tool_calls = None - - if payload_info["stream"]: - return StreamingResponse( - stream_generate_efficient( - payload_info['model'], - text, - tool_calls, - request, - ), - status_code=200, - media_type="text/event-stream", - ) - - else: - usage = { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0, - } - return Response( - content=generate(payload_info['model'], text, tool_calls, usage), - media_type="application/json", - ) - - if not (CACHE_CHAT_COMPLETION and valid_payload): - return None, None - - cache_key = construct_cache_key() - - if payload_info['caching'] and cache_key in db_dict: - return get_response_from_cache(cache_key), cache_key - - return None, cache_key - async def reverse_proxy(self, request: Request): """ Asynchronously handles reverse proxying the incoming request. @@ -480,23 +409,23 @@ async def reverse_proxy(self, request: Request): StreamingResponse: A FastAPI StreamingResponse containing the server's response. """ client_config = self.prepare_client(request, return_origin_header=False) - url_path = client_config["url_path"] + route_path = client_config["url_path"] self.handle_authorization(client_config) valid_payload, payload_info, payload = await self._handle_payload( - request, url_path + request, route_path ) uid = payload_info["uid"] - cached_response, cache_key = self._get_cached_response( - payload_info, valid_payload, request + cached_response, cache_key = get_cached_response( + payload_info, valid_payload, route_path, request ) if cached_response: return cached_response r = await self.send(client_config, data=payload) return StreamingResponse( - self.aiter_bytes(r, request, url_path, uid, cache_key), + self.aiter_bytes(r, request, route_path, uid, cache_key), status_code=r.status, media_type=r.headers.get("content-type"), ) diff --git a/openai_forward/settings.py b/openai_forward/settings.py index ccc941b..a6a5f2f 100644 --- a/openai_forward/settings.py +++ b/openai_forward/settings.py @@ -21,6 +21,7 @@ ) COMPLETION_ROUTE = os.environ.get("COMPLETION_ROUTE", "").strip() or "/v1/completions" +EMBEDDING_ROUTE = os.environ.get("EMBEDDING_ROUTE", "").strip() or "/v1/embeddings" ENV_VAR_SEP = "," OPENAI_BASE_URL = env2list("OPENAI_BASE_URL", sep=ENV_VAR_SEP) or [ @@ -49,6 +50,7 @@ if PRINT_CHAT: additional_start_info["print_chat"] = True +CACHE_EMBEDDING = os.environ.get("CACHE_EMBEDDING", "false").strip().lower() == "true" CACHE_CHAT_COMPLETION = ( os.environ.get("CACHE_CHAT_COMPLETION", "false").strip().lower() == "true" )