比较常用的连接池有数据库连接池和 HTTP Client 连接池。我自己也编写过连接池,例如 Thrift 连接池,以及向 RabbitMQ 队列插入消息的连接池。
首先剖析一下数据库连接池的设计与实现原理。DBUtils 是一个数据库连接池实现模块,它基于符合 DB-API 2 规范的数据库驱动模块,为数据库连接提供线程安全的封装,使多线程程序能够安全、高效地访问数据库。本文主要分析其中 PooledDB 的工作流程。
DBUtils.PooledDB 基于任意 DB-API 2 数据库模块,实现了一个连接池,其中的数据库连接是稳定的(steady,断开后可自动重连)、线程安全的、带缓存的,并且可以被透明地复用。
本文主要讨论 dedicated connections,即专用数据库连接。初始化连接池时需要指定 mincached、maxcached 和 maxconnections 等参数,它们分别表示连接池初始创建的空闲连接数、连接池中最多保留的空闲连接数,以及系统允许同时存在的最大连接数。此外,blocking 参数决定了当连接数已达上限、获取不到连接时,是阻塞等待可用连接,还是直接抛出 TooManyConnections 异常:
# Non-blocking pool: replace the condition's wait() with a function that
# raises TooManyConnections. Any code that would call self._condition.wait()
# to wait for a free connection then gets the exception immediately instead of
# blocking.
if not blocking:
    def wait():
        raise TooManyConnections
    self._condition.wait = wait
# Establish an initial number of idle database connections:idle = [self.dedicated_connection() for i in range(mincached)]while idle: idle.pop().close()
def close(self): """Close the pooled dedicated connection.""" # Instead of actually closing the connection, # return it to the pool for future reuse. if self._con: self._pool.cache(self._con) self._con = None
def cache(self, con):
    """Put a dedicated connection back into the idle cache.

    If the idle cache has room (or ``maxcached`` is 0/None, i.e. unlimited),
    the connection is reset and appended to the idle cache; otherwise it is
    really closed. In both cases the count of handed-out connections is
    decremented and one thread waiting on the condition (if any) is woken.

    :param con: the connection being returned to the pool
    """
    self._condition.acquire()
    try:
        if not self._maxcached or len(self._idle_cache) < self._maxcached:
            con._reset(force=self._reset)  # rollback possible transaction
            # the idle cache is not full, so put it there
            self._idle_cache.append(con)  # append it to the idle cache
        else:  # if the idle cache is already full,
            con.close()  # then close the connection
        # Applies to BOTH branches above: the connection is no longer handed
        # out, whether it was cached or closed.
        self._connections -= 1
        # Wake one thread that may be waiting for the count to drop below
        # maxconnections.
        self._condition.notify()
    finally:
        # Always release the condition's lock, even if reset/close raised.
        self._condition.release()
# try to get a dedicated connection self._condition.acquire() try: while (self._maxconnections and self._connections >= self._maxconnections): self._condition.wait() # connection limit not reached, get a dedicated connection try: # first try to get it from the idle cache con = self._idle_cache.pop(0) except IndexError: # else get a fresh connection con = self.steady_connection() else: con._ping_check() # check connection con = PooledDedicatedDBConnection(self, con) self._connections += 1 finally: self._condition.release()
# coding:utf-8
import logging
import threading
import time

try:  # Python 2
    import Queue
except ImportError:  # Python 3 renamed the module to ``queue``
    import queue as Queue

from kombu import Connection


class InsertQueue(object):
    """Pool of RabbitMQ (kombu) connections used to publish messages.

    At most ``maxActive`` connections exist at once. Up to ``maxIdle`` of them
    are kept in an idle queue between publishes. A pooled connection that has
    not been used for more than ``disable_time`` seconds is closed instead of
    being reused.
    """

    def __init__(self, host=None, port=None, virtual_host=None, heartbeat_interval=3,
                 name=None, password=None, logger=None, maxIdle=10, maxActive=50,
                 timeout=30, disable_time=20):
        """
        :param str host: Hostname or IP Address to connect to
        :param int port: TCP port to connect to
        :param str virtual_host: RabbitMQ virtual host to use
        :param int heartbeat_interval: How often to send heartbeats
        :param str name: auth credentials name
        :param str password: auth credentials password
        :param logger: logger to use; defaults to the ``logging`` module
        :param int maxIdle: maximum number of idle connections kept for reuse
        :param int maxActive: maximum number of connections that may exist at once
        :param timeout: seconds to wait for an idle connection when no new one
            can be created
        :param disable_time: seconds of inactivity after which a pooled
            connection is closed instead of reused
        """
        self.logger = logging if logger is None else logger
        self.host = host
        self.port = port
        self.virtual_host = virtual_host
        self.heartbeat_interval = heartbeat_interval
        self.name = name
        self.password = password
        self.mutex = threading.RLock()
        self.maxIdle = maxIdle
        self.maxActive = maxActive
        # How many more connections may be created (maxActive minus live ones).
        self.available = self.maxActive
        self.timeout = timeout
        # Idle connections waiting to be reused.
        self._queue = Queue.Queue(maxsize=self.maxIdle)
        self.disable_time = disable_time

    def get_new_connection_pipe(self):
        """Open a new connection plus producer, reserving one of the maxActive slots.

        :return: a new :class:`ConnectionPipe`
        :raises GetConnectionException: if no slot is free or connecting fails
        """
        with self.mutex:
            if self.available <= 0:
                raise GetConnectionException
            self.available -= 1
        conn = None
        try:
            conn = Connection(hostname=self.host, port=self.port,
                              virtual_host=self.virtual_host,
                              heartbeat=self.heartbeat_interval,
                              userid=self.name, password=self.password)
            producer = conn.Producer()
        except Exception:
            self.logger.exception("Create RabbitMQ connection failed")
            # Don't leak a half-built connection, and give the reserved slot back.
            if conn is not None:
                try:
                    conn.close()
                except Exception:
                    pass
            with self.mutex:
                self.available += 1
            raise GetConnectionException
        return ConnectionPipe(conn, producer)

    def get_connection_pipe(self):
        """Get a usable connection: reuse an idle one, else open one, else wait.

        Idle connections unused for more than ``disable_time`` seconds are
        closed and skipped.

        :return: a :class:`ConnectionPipe`
        :raises GetConnectionException: if no connection could be obtained,
            even after waiting ``timeout`` seconds
        """
        while True:
            connection_pipe = self._acquire_connection_pipe()
            if (time.time() - connection_pipe.use_time) > self.disable_time:
                # Stale: close it (freeing its slot) and look again.
                self.close(connection_pipe)
                continue
            return connection_pipe

    def _acquire_connection_pipe(self):
        """Return an idle or new pipe, waiting up to ``timeout`` seconds if needed."""
        try:
            return self._queue.get(False)
        except Queue.Empty:
            pass
        try:
            return self.get_new_connection_pipe()
        except GetConnectionException:
            pass
        try:
            return self._queue.get(timeout=self.timeout)
        except Queue.Empty:
            pass
        try:
            return self.get_new_connection_pipe()
        except GetConnectionException:
            self.logger.error("Too much connections, Get Connection Timeout!")
            # Previously fell through and used an unbound variable.
            raise

    def close(self, connection_pipe):
        """
        close the connection and the correlative channel, freeing its slot
        :param connection_pipe: the :class:`ConnectionPipe` to close
        :return: None
        """
        with self.mutex:
            self.available += 1
        connection_pipe.close()

    def insert_message(self, exchange=None, body=None, routing_key='', mandatory=True):
        """
        Publish one persistent message (delivery_mode=2) through a pooled connection.

        :param str exchange: exchange name
        :param str body: message
        :param str routing_key: routing key
        :param bool mandatory: AMQP ``mandatory`` flag, asking the broker to
            return the message if it cannot be routed to any queue (this is
            not a publisher confirm)
        :return: True if the message was published, False otherwise
        """
        connection_pipe = None
        published = False
        try:
            connection_pipe = self.get_connection_pipe()
            use_time = time.time()
            connection_pipe.channel.publish(exchange=exchange, body=body, delivery_mode=2,
                                            routing_key=routing_key, mandatory=mandatory)
            published = True
        except Exception:
            self.logger.exception("Insert message failed")
        finally:
            if connection_pipe is not None:
                if published:
                    # Healthy: remember when it was used and return it to the pool.
                    connection_pipe.use_time = use_time
                    try:
                        self._queue.put_nowait(connection_pipe)
                    except Queue.Full:
                        self.close(connection_pipe)
                else:
                    # The connection may be broken; never hand it out again.
                    self.close(connection_pipe)
        return published


class ConnectionPipe(object):
    """A kombu connection bundled with the producer created on it."""

    def __init__(self, connection, channel):
        self.connection = connection
        # Despite the name, this holds the producer used for publishing.
        self.channel = channel
        # Last time this pipe was used; used to detect stale idle connections.
        self.use_time = time.time()

    def close(self):
        """Close the underlying connection, ignoring errors (best effort)."""
        try:
            self.connection.close()
        except Exception:
            pass


class GetConnectionException(Exception):
    """Raised when no RabbitMQ connection can be obtained from the pool."""
    pass
# coding: utf-8
import errno
import logging
import socket
import threading
import time
from collections import deque

from gevent import Timeout
from gevent.event import AsyncResult
from kazoo.client import KazooClient
from thriftpy.protocol import TBinaryProtocolFactory
from thriftpy.thrift import TClient
from thriftpy.transport import (
    TBufferedTransportFactory,
    TSocket,
)
from thriftpy.transport import TTransportException

from error import CTECThriftClientError


class ClientPool(object):
    """Load-balanced pool of thriftpy clients for one Thrift service.

    Unknown attribute lookups become remote calls: ``pool.foo(x)`` borrows a
    client, calls ``client.foo(x)`` and returns the client to the pool.
    Provider addresses come either from a fixed list or from the children of
    a ZooKeeper node, and are used in round-robin order.
    """

    def __init__(self, service, server_hosts=None, zk_path=None, zk_hosts=None, logger=None,
                 max_renew_times=3, maxActive=20, maxIdle=10, get_connection_timeout=30,
                 socket_timeout=30, disable_time=3):
        """
        :param service: Thrift service definition passed to TClient
        :param server_hosts: provider addresses, a list like ['ip:port', 'ip:port']
        :param zk_path: ZooKeeper path whose children are the provider addresses
        :param zk_hosts: ZooKeeper hosts, comma separated
        :param logger: logger to use; defaults to the ``logging`` module
        :param max_renew_times: maximum attempts per call (reconnecting in between)
        :param maxActive: maximum number of open connections
        :param maxIdle: maximum number of idle connections kept in the pool
        :param get_connection_timeout: seconds to wait for a connection when
            the pool is exhausted
        :param socket_timeout: socket timeout in seconds (multiplied by 1000
            for TSocket)
        :param disable_time: seconds of inactivity after which an idle
            connection is closed instead of reused
        :raises CTECThriftClientError: if neither zk_hosts nor server_hosts is given
        """
        # Round-robin queue of 'ip:port' strings.
        self.load_balance_queue = deque()
        self.service = service
        self.lock = threading.RLock()
        self.max_renew_times = max_renew_times
        self.maxActive = maxActive
        self.maxIdle = maxIdle
        # Idle clients ready for reuse.
        self.connections = set()
        # Number of open clients, idle or borrowed.
        self.pool_size = 0
        self.get_connection_timeout = get_connection_timeout
        # AsyncResults of callers waiting for a client to be returned.
        self.no_client_queue = deque()
        self.socket_timeout = socket_timeout
        self.disable_time = disable_time
        self.logger = logging if logger is None else logger
        # Always define these: a missing attribute would fall through to
        # __getattr__ and come back as a (truthy) remote-call closure.
        self.kazoo_client = None
        self.zk_path = zk_path
        self.zk_hosts = zk_hosts
        self.server_hosts = server_hosts
        if zk_hosts:
            self.kazoo_client = KazooClient(hosts=zk_hosts)
            self.kazoo_client.start()
            # Keep the provider list in sync with ZooKeeper.
            self.kazoo_client.ChildrenWatch(path=self.zk_path, func=self.watcher)
            self.__refresh_thrift_connections(self.kazoo_client.get_children(self.zk_path))
        elif server_hosts:
            self.load_balance_queue.extendleft(self.server_hosts)
        else:
            raise CTECThriftClientError('没有指定服务器获取方式!')

    def get_new_client(self):
        """Open a new client to the next provider in round-robin order (thread-safe).

        The address is taken from the right end of the load-balance queue and
        put back on the left end.

        :return: a connected TClient, or None if maxActive clients already exist
        :raises CTECThriftClientError: if no provider is known
        """
        with self.lock:
            if self.pool_size >= self.maxActive:
                return None
            try:
                ip = self.load_balance_queue.pop()
            except IndexError:
                raise CTECThriftClientError('没有可用的服务提供者列表!')
            if not ip:
                return None
            self.load_balance_queue.appendleft(ip)
            host, _, port = ip.rpartition(':')
            t_socket = TSocket(host, int(port), socket_timeout=1000 * self.socket_timeout)
            transport = TBufferedTransportFactory().get_transport(t_socket)
            protocol = TBinaryProtocolFactory().get_protocol(transport)
            transport.open()
            client = TClient(self.service, protocol)
            self.pool_size += 1
            return client

    def close(self):
        """Stop the ZooKeeper client (if any) and close all idle connections.

        Clients currently borrowed by callers are not touched.
        """
        if self.kazoo_client is not None:
            self.kazoo_client.stop()
        with self.lock:
            self._close_idle_clients()

    def watcher(self, children):
        """ZooKeeper children watcher: refresh the provider list.

        :param children: child node names, i.e. the provider 'ip:port' list
        """
        self.__refresh_thrift_connections(children)

    def __refresh_thrift_connections(self, children):
        """Replace the provider list and drop idle connections (thread-safe).

        :param children: the new provider 'ip:port' list
        """
        with self.lock:
            self.load_balance_queue.clear()
            # Close and uncount the discarded idle clients; merely clearing the
            # set would leak their sockets and keep pool_size inflated forever.
            self._close_idle_clients()
            self.load_balance_queue.extendleft(children)

    def _close_idle_clients(self):
        """Close every idle client and remove it from pool_size (caller holds self.lock)."""
        while self.connections:
            client = self.connections.pop()
            try:
                self.close_one_client(client)
            except Exception as e:
                self.logger.warning("Close idle connection error, %s", e)

    def __getattr__(self, name):
        """Turn ``pool.<name>(...)`` into a pooled remote call of ``<name>``.

        The returned function borrows a client (waiting up to
        get_connection_timeout seconds if the pool is exhausted) and makes up
        to max_renew_times attempts. It returns the remote result, or None if
        no client was obtained or the call failed.
        """
        if name.startswith('__') and name.endswith('__'):
            # Protocol lookups (copy, pickle, ...) are not remote calls.
            raise AttributeError(name)

        def method(*args, **kwds):
            client = self.get_client_from_pool()
            if client is None:
                client = self._wait_for_client()
            if client is None:
                return None
            return self._call_with_retries(client, name, args, kwds)

        return method

    def _wait_for_client(self):
        """Wait up to get_connection_timeout seconds for a client to be handed back.

        :return: a client, or None on timeout
        """
        async_result = AsyncResult()
        with self.lock:
            # Re-check under the lock: a client returned after the caller last
            # looked at the pool would otherwise sit idle while we wait.
            client = self.get_client_from_pool()
            if client is not None:
                return client
            self.no_client_queue.appendleft(async_result)
        time_out = Timeout(self.get_connection_timeout)
        time_out.start()
        try:
            return async_result.get()  # blocks until put_back_connections sets it
        except Timeout:
            pass
        finally:
            time_out.cancel()
        with self.lock:
            try:
                self.no_client_queue.remove(async_result)
            except ValueError:
                # put_back_connections dequeued this waiter and set a client
                # just as the timeout fired; take it rather than leak it.
                return async_result.get(block=False)
        self.logger.error("Get Connection Timeout!")
        return None

    def _renew_client(self, client):
        """Close a client whose connection was dropped and open a replacement (may be None)."""
        self.close_one_client(client)
        return self.get_new_client()

    def _call_with_retries(self, client, name, args, kwds):
        """Call ``client.<name>(*args, **kwds)``, reconnecting on dropped connections.

        A client that completed the call is returned to the pool. A client that
        failed any other way is closed.

        :return: the remote result, or None on failure
        """
        for _ in range(self.max_renew_times):
            put_back_flag = True
            try:
                client.last_use_time = time.time()
                fun = getattr(client, name, None)
                return fun(*args, **kwds)
            except socket.timeout:
                self.logger.error("Socket Timeout!")
                # Close the connection: a late reply would otherwise be read as
                # the answer to the next request on this socket.
                put_back_flag = False
                self.close_one_client(client)
                client = None
                break
            except TTransportException as e:
                put_back_flag = False
                if e.type == TTransportException.END_OF_FILE:
                    self.logger.warning("Socket Connection Reset Error,%s", e)
                    client = self._renew_client(client)
                    if client is None:
                        break
                else:
                    self.logger.error("Socket Error,%s", e)
                    self.close_one_client(client)
                    client = None
                    break
            except socket.error as e:
                put_back_flag = False
                if e.errno == errno.ECONNABORTED:
                    self.logger.warning("Socket Connection aborted Error,%s", e)
                    client = self._renew_client(client)
                    if client is None:
                        break
                else:
                    self.logger.error("Socket Error, %s", e)
                    self.close_one_client(client)
                    client = None
                    break
            except Exception as e:
                put_back_flag = False
                self.logger.error("Thrift Error, %s", e)
                self.close_one_client(client)
                client = None
                break
            finally:
                # Return a client that completed the call to the pool.
                if put_back_flag:
                    self.put_back_connections(client)
        if client is not None:
            # Attempts ran out right after reconnecting (or max_renew_times is
            # 0): the client is healthy but unused, so pool it instead of
            # leaking it together with its pool_size slot.
            client.last_use_time = time.time()
            self.put_back_connections(client)
        return None

    def close_one_client(self, client):
        """Close one client and remove it from pool_size (thread-safe).

        :param client: the client to close
        """
        with self.lock:
            try:
                client.close()
            finally:
                # Uncount it even if close() raised, so capacity is not lost.
                self.pool_size -= 1

    def put_back_connections(self, client):
        """Return a client to the pool (thread-safe).

        1. If a caller is waiting for a client, hand it over directly (the
           longest-waiting caller first).
        2. Otherwise, if fewer than maxIdle clients are idle, keep it.
        3. Otherwise, close it.

        :param client: the client being returned
        """
        with self.lock:
            if self.no_client_queue:
                self.no_client_queue.pop().set(client)
            elif len(self.connections) < self.maxIdle:
                self.connections.add(client)
            else:
                self.close_one_client(client)

    def get_client_from_pool(self):
        """Take an idle client, or open a new one (thread-safe).

        An idle client unused for disable_time seconds or more is closed and
        replaced by a new one.

        :return: a client, or None if the pool is empty and maxActive is reached
        """
        client = self.get_one_client_from_pool()
        if client is not None:
            if (time.time() - client.last_use_time) < self.disable_time:
                return client
            # Idle too long; the server may have dropped it, so replace it.
            self.close_one_client(client)
        return self.get_new_client()

    def get_one_client_from_pool(self):
        """Pop one idle client, or return None if there is none (thread-safe)."""
        with self.lock:
            try:
                return self.connections.pop()
            except KeyError:
                return None