python multiprocessing多进程导致数据库连接不可用问题
问题背景
公司的算法服务为了能够管理每次算法服务请求的计算耗时同时使算法服务能够充分利用CPU的多核处理能力,在响应算法请求时使用了pebble包中的concurrent.process注解,使每个请求的处理由单开的进程来完成,并能够设置进程处理的超时时长。
产生的问题
在算法的处理中需要与pg库操作,使用了psycopg2包来建立pg库连接池。在linux系统下默认使用fork模式创建新进程。 在实际算法服务运行时,发现当出现并发操作时数据库连接会报如下异常:
psycopg2.OperationalError: lost synchronization with server: got message type "
type后面的内容每次可能不太一样。 经过研究发现出现这种情况的问题在于同一个连接同时被两个以上的进程操作导致。这其中包含对以下几个问题的研究,做一下记录。
涉及的知识点
psycopg2的数据库连接池
psycopg2的AbstractConnectionPool类中包含了对连接池中获取连接和连接放回的实现。连接池是由python的list结构维持,取连接使用pop()方法,放回连接使用append()方法。因此连接池的取放操作类似于栈的行为。如果一个连接池只有一个线程在使用,那么每次取到的一定是列表末尾的那个连接。
进程的初始化方式
multiprocessing包提供三种进程初始化方式,spawn, fork, forkserver。其中:
spawn方式不会继承原进程中的文件描述符、网络连接等,所以大部分资源需要重新初始化,所以效率低一些,但是比较安全; fork方式会以copy on write的方式继承原进程的所有数据,包括文件描述符、网络连接等。所以效率比较高,但是不安全; forkserver方式的表现和spawn比较接近,但是速度会更快一些;
三种方式的具体解释可以参考python文档:https://docs.python.org/3.10/library/multiprocessing.html#contexts-and-start-methods
问题的原因
所以,问题的原因就是算法服务启动新进程来响应算法请求时使用了fork方式初始化子进程,从而继承了父进程的连接池数据。此时如果有并发的请求发生,另一个子进程也会继承父进程的同一份连接池数据,因此两个子进程存在同时使用数据库连接的情况,而根据对psycopg2的连接池实现的了解,两个子进程只要存在并发情况,必然会同时操作同一个数据库连接。就会造成上面的异常。 这个问题在psycopg2的文档中也有描述,参考https://www.psycopg.org/docs/usage.html#thread-and-process-safety。
解决思路
- 使用fork方式,需要注意新进程启动后关闭之前的连接(该操作不会造成连接在父进程中被关闭,参考文档:https://www.python.org/dev/peps/pep-0433/#inherited-file-descriptors-issues),重新获取新的连接,在进程结束前注意关闭连接;
- 使用spawn或者fork server方式创建进程;
python中的id()的返回
定位问题时发现子进程中继承的父进程中的数据库连接池以及所有连接在日志中打印出的id内容是完全相等的。这是因为在cpython中,id方法返回的是对象在进程中的地址,是相对于进程初始地址的偏移量,不是内存中的绝对地址,所以,只能保证进程内唯一,进程间是有可能出现id值相等的对象的。
TODO:要满足多核计算、流程耗时可控、资源共享最大化。算法服务应如何实现
代码附录
出现数据库连接冲突问题的简化代码如下:
from typing import Dict, Tuple
from functools import lru_cache, wraps
from datetime import datetime
import time
import random
import traceback
import os
import multiprocessing as mp
import psycopg2
from psycopg2 import pool
from loguru import logger
import pandas as pd
from pebble import concurrent
from pebble.common import ProcessExpired
from _timing import timing
logger.add('my_log.log')
class RDS_IO:
def __init__(self):
self._pool = None
def is_init(self):
"返回rds连接是否初始化"
if self._pool is None:
return False
return True
def init_pg(
self,
host: str,
port: str,
user: str,
password: str,
dbname: str,
minconn: int = 8,
maxconn: int = 16,
):
self._pool = psycopg2.pool.ThreadedConnectionPool(
minconn=minconn,
maxconn=maxconn,
host=host,
port=port,
dbname=dbname,
user=user,
password=password,
)
logger.info(f"成功初始化pg数据库连接池,连接数{minconn} - {maxconn}")
self.minconn = minconn
self.maxconn = maxconn
self.host = host
self.port = port
self.dbname = dbname
self.user = user
self.password = password
def reinit_pg(self):
self.close()
self.init_pg(
self.host,
self.port,
self.user,
self.password,
self.dbname,
self.minconn,
self.maxconn,
)
logger.info(f"重连数据库成功")
def pool(self):
return self._pool
def execute_sql(
self,
sql: str,
data_type: Dict = {},
):
time_tag = datetime.today().strftime("%y%m%d")
data_type = tuple(sorted(data_type.items()))
ret = self._execute_sql(
time_tag=time_tag,
sql=sql,
data_type=data_type,
)
if ret is None:
msg = f"取数据为空\n\n{sql}"
raise Exception(msg)
ret = ret.copy(deep=True)
return ret
def _execute_sql(
self,
time_tag,
sql: str,
data_type: Tuple,
) -> pd.DataFrame:
"""
time_tag: 用作时间标记,代表了缓存的有效期
"""
data_type = dict(data_type)
if not self.is_init():
msg = "pg数据库连接 未初始化,请用init_pg(*)初始化"
logger.critical(msg)
raise RuntimeError(msg)
logger.debug(f"query {sql}")
ok = False
max_retry_cnt = 5
for iid in range(max_retry_cnt):
if iid == 3:
self.reinit_pg()
if iid > 0:
sleep_sec = random.random() * 10
time.sleep(sleep_sec)
try:
print(f'[getconn前] 线程: {os.getpid()} 连接池锁id和状态: {id(self._pool._lock)}, {self._pool._lock.locked()}')
conn = self._pool.getconn()
print(f'[getconn后] 线程: {os.getpid()} 连接池锁id和状态: {id(self._pool._lock)}, {self._pool._lock.locked()}')
logger.info(f'[取] {os.getpid()} 连接状态{conn.get_transaction_status()}, 池内剩余: {len(self._pool._pool)}, {conn}')
logger.info(
f'池状态: {self._pool.closed}, 池连接状态, {[enum.get_transaction_status() for enum in self._pool._pool]}, 池key-连接: {self._pool._used}, 池连接-key: {self._pool._rused}, 池key: {self._pool._keys}')
flag = False
with conn.cursor() as cur:
try:
cur.execute(sql)
rows = cur.fetchall()
col_names = []
for elt in cur.description:
col_names.append(elt[0])
ok = True
if iid > 0:
logger.info(f"第{iid+1}次 重试成功")
except Exception as e:
logger.error(f"第{iid+1}次 执行 sql 失败\n{e}, 丢弃该连接")
logger.error(f'[异常] {os.getpid()} 连接状态{conn.get_transaction_status()}, 池内剩余: {len(self._pool._pool)}, {conn}')
raise e
logger.info(f'[放前] {os.getpid()} 连接状态{conn.get_transaction_status()}, 池内剩余: {len(self._pool._pool)}, {conn}')
self._pool.putconn(conn, close=flag)
logger.info(f'[放后] {os.getpid()} 连接状态{conn.get_transaction_status()}, 池内剩余: {len(self._pool._pool)}, {conn}')
logger.info(
f'池状态: {self._pool.closed}, 池连接状态, {[enum.get_transaction_status() for enum in self._pool._pool]}, 池key-连接: {self._pool._used}, 池连接-key: {self._pool._rused}, 池key: {self._pool._keys}')
logger.info(str(self._pool._pool))
except Exception as err:
logger.error(str(err))
logger.error(traceback.format_exc())
if ok:
break
else:
err_msg = f"pg库请求失败 累计重试{max_retry_cnt}次"
logger.error(err_msg)
return None
if len(rows) < 1:
return None
ret = pd.DataFrame(
rows,
columns=col_names,
)
ret = ret.astype(dtype=data_type)
return ret
def close(self):
self._pool.closeall()
self._pool = None
logger.info("数据库连接清理完毕!!!")
rds_io_ins = RDS_IO()
rds_host = "192.168.XXX.XXX"
rds_port = 5432
rds_dbname = "XXX"
rds_user = "XXX"
rds_password = "XXX"
ddd = {1:1, 2:2, 3:3}
logger.info("=======================准备初始化连接池")
rds_io_ins.init_pg(host=rds_host, port=rds_port, user=rds_user, password=rds_password, dbname=rds_dbname, minconn=2, maxconn=10)
def new_process(f):
@wraps(f)
def func_wrapper(*args, **kwargs):
result = f(*args, **kwargs)
return result
return func_wrapper
@concurrent.process(timeout=None, context=mp.get_context('fork'))
@new_process
def select_test():
for i in range(0,60):
logger.info(f"第{i}次执行sql")
data = rds_io_ins.execute_sql("select * from road_segment_dongguan limit 500")
logger.info(f"{data}")
def action():
future = select_test()
result = future.result()
logger.info(f"{result}")
def timed_select_test():
return timing.timeout(20)(select_test)
from functools import wraps
from math import radians
from threading import Thread
import os
from multiprocessing import managers, Value
from loguru import logger
from rds_io import rds_io_ins
from rds_io import select_test
from rds_io import action
if __name__ == '__main__':
for i in range(10):
thread = Thread(
target=action,
args=(
),
)
thread.start()
# requirements.txt
loguru
pebble
psycopg2
pandas
|