import pymongo
from lib.csv import Csv
from lib.util import *
from pymongo import ReturnDocument


class MongoDB:
    """Thin convenience wrapper around a single pymongo collection.

    Targets the pymongo 3.x API: ``Database.authenticate``,
    ``Collection.update`` and ``Collection.group`` used below were all
    removed in pymongo 4.0 / MongoDB 4.2.
    """

    def __init__(self, host='127.0.0.1', port=27017, db='test', collection='test', user=None, password=None):
        """Connect to a MongoDB instance and select a database/collection.

        :param host: server ip
        :param port: server port
        :param db: database name
        :param collection: collection name
        :param user: optional user name (auth skipped when empty/None)
        :param password: optional password (auth skipped when empty/None)
        """
        self.host = host
        self.port = port
        self.client = pymongo.MongoClient(host=self.host, port=self.port)
        if (user is not None and user != '') and (password is not None and password != ''):
            # authenticate() only exists in pymongo <= 3.x
            self.db = self.client.admin
            self.db.authenticate(user, password)
        self.db = self.client[db]  # database handle
        self.collection = self.db[collection]  # collection handle
        # NOTE: list_database_names()/list_collection_names() only report
        # namespaces that already hold data, so these are best-effort
        # warnings, not errors.
        if db not in self.client.list_database_names():
            print("数据库不存在!")
        if collection not in self.db.list_collection_names():
            print("表不存在!")

    def __str__(self):
        """Human-readable summary: host/port, db name, collection name, doc count."""
        # Use the public .name properties instead of the name-mangled
        # _Database__name / _Collection__name attributes.
        db = self.db.name
        collection = self.collection.name
        num = self.collection.find().count()  # cursor.count() is deprecated in pymongo >= 3.7
        return "host: {}, port: {}\n数据库{} 表{} 共{}条数据".format(self.host, self.port, db, collection, num)

    def __len__(self):
        """Total number of documents in the collection."""
        return self.collection.find().count()  # deprecated cursor.count(); kept for pymongo 3.x

    def close(self):
        """Close the underlying MongoClient connection pool."""
        self.client.close()

    def find(self, query, offset=-1, limit=0, sort_keys=None, projection=None):
        """Run find() and return the results as a list.

        sort is applied when sort_keys is not None, limit when limit > 0,
        and skip only when both offset >= 0 and limit > 0 — exactly the
        combinations of the original branch ladder, without duplication.
        """
        cursor = self.collection.find(query, projection)
        if sort_keys is not None:
            cursor = cursor.sort(sort_keys)
        if offset >= 0 and limit > 0:
            cursor = cursor.skip(offset)
        if limit > 0:
            cursor = cursor.limit(limit)
        return list(cursor)

    def find_one(self, query, offset=0, sort_keys=None, projection=None):
        """Return the first matching document, or None when there is no match."""
        items = self.find(query, offset, 1, sort_keys, projection)
        return items[0] if items else None

    def count(self, query):
        """Number of documents matching `query`."""
        # cursor.count() is deprecated (pymongo >= 3.7); kept because the
        # rest of this class targets the pymongo 3.x API.
        return int(self.collection.find(query).count())

    def distinct(self, query, filed='_id'):
        """Distinct values of a field among documents matching `query`.

        NOTE: the parameter is (mis)spelled ``filed``; kept as-is for
        backward compatibility with existing keyword-argument callers.
        """
        return list(self.collection.find(query).distinct(filed))

    def aggregate(self, query, group, sort=None, limit=None):
        """Run a $match -> $group [-> $sort] [-> $limit] aggregation pipeline."""
        pipeline = [{'$match': query}, {'$group': group}]
        if sort:
            pipeline.append({'$sort': sort})
        if limit:
            pipeline.append({'$limit': limit})
        return list(self.collection.aggregate(pipeline))

    def insert_many(self, items):
        """Insert a list of documents; returns InsertManyResult."""
        return self.collection.insert_many(items)

    def insert_one(self, item):
        """Insert a single document; returns InsertOneResult."""
        return self.collection.insert_one(item)

    def set_on_insert(self, query, set_on_insert):
        """Upsert that only sets fields when the document is newly inserted."""
        # Collection.update() is deprecated (use update_one); kept so the
        # legacy dict-style return value stays unchanged for callers.
        return self.collection.update(query, {'$setOnInsert': set_on_insert}, upsert=True)

    def upsert_query(self, query, up):
        """Apply update `up` to the first match of `query`, inserting if absent."""
        # Deprecated Collection.update(); kept for return-value compatibility.
        return self.collection.update(query, up, upsert=True)

    def update_one(self, query, up):
        """Update the first document matching `query`."""
        return self.collection.update_one(query, up)

    def upsert_one(self, query, up):
        """Update the first match of `query`, inserting a document if none exists."""
        return self.collection.update_one(query, up, upsert=True)

    def update_many(self, query, up):
        """Update every document matching `query`."""
        return self.collection.update_many(query, up)

    def delete_one(self, query):
        """Delete the first document matching `query`."""
        return self.collection.delete_one(query)

    def upsert_id(self, _id, up):
        """Upsert by _id and return the document as it is AFTER the update."""
        query = {'_id': _id}
        return self.collection.find_one_and_update(query, up, upsert=True, return_document=ReturnDocument.AFTER)

    def bulk_write(self, requests):
        """Execute a list of write operations in bulk."""
        return self.collection.bulk_write(requests)

    def scan(self, query, limit=5000, left=-1, right=-1, index_field='_id', index_field_typ=int, total=-1,
             sort_keys=None, projection=None, id_typ=ObjectId, print_log=True):
        """Page through all documents matching `query`, keyed on `index_field`.

        Instead of skip/limit paging, each page restarts from the last seen
        index value ($gte/$lte last_id), so the server never has to skip.
        Documents straddling page boundaries are de-duplicated by _id.

        :param limit: page size per round trip
        :param left: optional lower bound on index_field (> 0 to enable)
        :param right: optional upper bound on index_field (> 0 to enable)
        :param index_field: field to walk; must exist on the documents
        :param index_field_typ: type used by safe_get to coerce index values
        :param total: stop after collecting this many docs (<= 0 = no cap)
        :param id_typ: type used by safe_get to coerce _id for de-duplication
        :return: list of documents, or None when nothing matches at all
        """
        if sort_keys is None:
            sort_keys = [('_id', 1)]
        if right > left > 0:
            query[index_field] = {
                '$gte': left,
                '$lte': right,
            }
            # sort_keys is always truthy here (set above); a dead
            # "if not sort_keys" re-assignment was removed.
        elif right > 0:
            query[index_field] = {
                '$lte': right,
            }
            sort_keys = [(index_field, -1)]
        elif left > 0:
            query[index_field] = {
                '$gte': left,
            }
            sort_keys = [(index_field, 1)]
        if print_log:
            print(query)
        one = self.find_one(query, sort_keys=sort_keys, projection=projection)
        if not one:
            return
        if index_field not in one.keys():
            print('invalid index_field: {}'.format(index_field))
            return
        last_id = safe_get(one, index_field, index_field_typ, index_field_typ())
        more = True
        loop_cnt = 0
        all_docs = []
        repeat_map = {}
        while more:
            loop_cnt += 1
            if right > left > 0:
                # Keep the upper bound in the loop query too; the original
                # dropped '$lte: right' here and relied solely on the
                # per-document check below, scanning pages past the range.
                query[index_field] = {
                    '$gte': last_id,
                    '$lte': right,
                }
                sort_keys = [(index_field, 1)]
            elif right > 0:
                # right-only: walk downwards from the bound
                query[index_field] = {
                    '$lte': last_id,
                }
                sort_keys = [(index_field, -1)]
            else:
                # left-only and unbounded scans both walk upwards
                query[index_field] = {
                    '$gte': last_id,
                }
                sort_keys = [(index_field, 1)]
            if print_log:
                print('loop_query: {}, sort_keys: {}'.format(query, sort_keys))
            docs = self.find(query, offset=0, limit=limit, sort_keys=sort_keys, projection=projection)
            added = 0
            for d in docs:
                last_id = safe_get(d, index_field, index_field_typ, index_field_typ())
                if 0 < right < last_id:
                    # past the requested range — nothing later can qualify
                    more = False
                    break
                # de-duplicate documents straddling page boundaries
                _id = safe_get(d, '_id', id_typ, id_typ())
                if _id in repeat_map:
                    continue
                all_docs.append(d)
                repeat_map[_id] = 1
                added += 1
                if 0 < total <= len(all_docs):
                    # break (not continue): the original kept appending the
                    # rest of the page, overshooting `total`
                    more = False
                    break
            if print_log:
                print('loop_cnt: {}, next: {}, examined: {}, total: {}'.format(loop_cnt, last_id, len(docs),
                                                                               len(all_docs)))
            if len(docs) < limit:
                more = False
            elif more and added == 0:
                # Stall guard: a full page of already-seen documents means
                # more than `limit` docs share one index value and last_id
                # can never advance — the original looped forever here.
                print('scan stalled at {}: more than {} docs share one index value'.format(last_id, limit))
                more = False
        return all_docs

    def full_scan(self, id_typ=int, projection=None):
        """Scan the entire collection (no query, big pages)."""
        return self.scan({}, limit=10000, index_field_typ=id_typ, projection=projection)

    # Equivalent mongo-shell group example:
    #   db.getCollection("total_journal_issue").group({
    #       key: {journal_id: 1},
    #       reduce: function(curr, result) {
    #           // curr: current document, result: accumulator
    #           if (typeof curr.article_count != "undefined") {
    #               result.sum_article = curr.article_count + result.sum_article;
    #           }
    #       },
    #       initial: {sum_article: 0}  // initialise the sum_article field
    #   })
    # SQL equivalent:
    #   SELECT journal_id, SUM(article_count) AS sum_article
    #   FROM total_journal_issue GROUP BY journal_id
    def group(self, key, condition, initial, reducer):
        """Wrapper for Collection.group() (removed in MongoDB 4.2 / pymongo 4)."""
        return list(self.collection.group(key, condition, initial, reducer))

    # Paged queries — two common flavours:
    # 1. A plain query (e.g. {ct: {$gte: x, $lt: y}}) that may match a huge
    #    number of documents -> page at the database with skip/limit.
    # 2. A (possibly very long) list of values for some field (mid_list,
    #    _id_list, ...) -> page over the id list itself with $in chunks;
    #    typical for users / comments / posts lookups.

    def page_find_by_query(self, query, sort_keys=None, offset=0, page=200, file='', projection=None,
                           print_log=True):
        """Flavour 1: fetch all matches of `query` in pages of `page` docs.

        :param file: absolute path of a CSV file to also stream results to;
                     pass '' (default) to skip file output.
        :return: list of all matching documents (empty on bad query type)
        """
        if not isinstance(query, dict):
            return list()
        ret_list = []
        skip = offset
        limit = page
        csv_ = Csv(file, ['datas']) if file != '' else None
        more = True
        while more:
            results = self.find(query, skip, limit, sort_keys, projection)
            if print_log:
                print('s: {}, e: {}'.format(skip, skip + limit))
            ret_list.extend(results)
            if csv_:
                csv_.extend(results)
            if len(results) < limit:
                more = False
            else:
                skip += limit
        return ret_list

    def page_find_by_list(self, ids, field='_id', sort_keys=None, offset=0, page=100, page_db=200, file='',
                          projection=None, print_log=True):
        """Flavour 2: fetch documents whose `field` is in `ids`, chunking the list.

        :param page: size of each $in chunk sliced from `ids`
        :param page_db: database-side page size used for each chunk query
        :param file: absolute path of a CSV file to also stream results to;
                     pass '' (default) to skip file output.
        :return: list of all matching documents (empty on bad argument types)
        """
        if (not isinstance(ids, list)) or (not isinstance(field, str)):
            return list()
        ret_list = []
        skip = offset
        limit = page
        total = len(ids)
        csv_ = Csv(file, ['datas']) if file != '' else None
        while True:
            if skip + limit > total:
                limit = total - skip  # final, partial chunk
            if skip >= total or limit <= 0:
                break
            tmp_ids = ids[skip: skip + limit]
            query = {field: {'$in': tmp_ids}}
            results = self.page_find_by_query(query, sort_keys, 0, page_db, '', projection, print_log=print_log)
            ret_list.extend(results)
            if csv_:
                csv_.extend(results)
            skip += limit
        return ret_list