292 lines
11 KiB
Python
292 lines
11 KiB
Python
import pymongo
|
||
from lib.csv import Csv
|
||
from lib.util import *
|
||
from pymongo import ReturnDocument
|
||
|
||
|
||
class MongoDB:
    """Convenience wrapper around one MongoDB database/collection pair.

    Written against the PyMongo 3.x API (3.7+): ``group`` and the legacy
    ``Collection.update`` used below were removed in PyMongo 4.0.
    """

    def __init__(self, host='127.0.0.1', port=27017, db='test', collection='test', user=None, password=None):
        """Connect to MongoDB and bind the given database and collection.

        :param host: database instance host / IP
        :param port: database instance port
        :param db: database name
        :param collection: collection (table) name
        :param user: user name; authentication is only used when both
            ``user`` and ``password`` are non-empty
        :param password: password
        """
        self.host = host
        self.port = port

        if user and password:
            # Credentials go to the client instead of Database.authenticate(),
            # which was deprecated in PyMongo 3 and removed in 4.0.
            # authSource='admin' keeps authenticating against the admin database.
            self.client = pymongo.MongoClient(host=self.host, port=self.port,
                                              username=user, password=password,
                                              authSource='admin')
        else:
            self.client = pymongo.MongoClient(host=self.host, port=self.port)

        self.db = self.client[db]  # database
        self.collection = self.db[collection]  # collection ("table")

        # Only a warning: MongoDB creates databases/collections lazily on first write.
        if db not in self.client.list_database_names():
            print("数据库不存在!")
        if collection not in self.db.list_collection_names():
            print("表不存在!")

    def __str__(self):
        """Describe the connection: host, port, database, collection and document count."""
        # estimated_document_count() issues the same metadata-based count that the
        # removed Cursor.count() used for an empty filter.
        num = self.collection.estimated_document_count()
        return "host: {}, port: {}\n数据库{} 表{} 共{}条数据".format(
            self.host, self.port, self.db.name, self.collection.name, num)

    def __len__(self):
        """Return the (metadata-based) number of documents in the collection."""
        return self.collection.estimated_document_count()

    def close(self):
        """Close the underlying client connection."""
        self.client.close()

    def find(self, query, offset=-1, limit=0, sort_keys=None, projection=None):
        """Return the documents matching ``query`` as a list.

        :param query: filter document
        :param offset: number of documents to skip; values <= 0 skip nothing
        :param limit: maximum number of documents; values <= 0 mean no limit
        :param sort_keys: PyMongo sort specification, e.g. ``[('_id', 1)]``
        :param projection: fields to include / exclude
        :return: list of documents
        """
        cursor = self.collection.find(query, projection)
        if sort_keys is not None:
            cursor = cursor.sort(sort_keys)
        # The offset is applied on its own; it used to be dropped unless a
        # positive limit was also given.
        if offset > 0:
            cursor = cursor.skip(offset)
        if limit > 0:
            cursor = cursor.limit(limit)
        return list(cursor)

    def find_one(self, query, offset=0, sort_keys=None, projection=None):
        """Return the ``offset``-th matching document (after sorting), or None if there is none."""
        items = self.find(query, offset, 1, sort_keys, projection)
        return items[0] if items else None

    def count(self, query):
        """Return the number of documents matching ``query``."""
        # count_documents() replaces the removed Cursor.count(). It rejects
        # $where / $near / $nearSphere filters ($expr / $geoWithin work instead).
        return int(self.collection.count_documents(query))

    def distinct(self, query, filed='_id'):
        """Return the distinct values of field ``filed`` among documents matching ``query``.

        The parameter is spelled ``filed`` (sic); the name is kept for callers
        passing it by keyword.
        """
        return list(self.collection.find(query).distinct(filed))

    def aggregate(self, query, group, sort=None, limit=None):
        """Run a ``$match`` -> ``$group`` [-> ``$sort``] [-> ``$limit``] pipeline.

        :param query: ``$match`` filter
        :param group: ``$group`` stage specification
        :param sort: optional ``$sort`` specification (skipped when falsy)
        :param limit: optional ``$limit`` value (skipped when falsy)
        :return: list of result documents
        """
        pipeline = [{'$match': query}, {'$group': group}]
        if sort:
            pipeline.append({'$sort': sort})
        if limit:
            pipeline.append({'$limit': limit})
        return list(self.collection.aggregate(pipeline))

    def insert_many(self, items):
        """Insert several documents; returns PyMongo's InsertManyResult."""
        return self.collection.insert_many(items)

    def insert_one(self, item):
        """Insert one document; returns PyMongo's InsertOneResult."""
        return self.collection.insert_one(item)

    def set_on_insert(self, query, set_on_insert):
        """Upsert with ``$setOnInsert``: the fields are only written when a new document is created."""
        # NOTE: legacy Collection.update (deprecated in PyMongo 3, removed in 4.0)
        # is kept because callers may rely on its raw dict return value.
        return self.collection.update(query, {'$setOnInsert': set_on_insert}, upsert=True)

    def upsert_query(self, query, up):
        """Upsert the first document matching ``query`` with ``up`` (operators or a replacement document)."""
        # NOTE: legacy Collection.update kept on purpose: unlike update_one it also
        # accepts a replacement document, and callers may rely on its dict result.
        return self.collection.update(query, up, upsert=True)

    def update_one(self, query, up):
        """Apply update operators ``up`` to the first document matching ``query``."""
        return self.collection.update_one(query, up)

    def upsert_one(self, query, up):
        """Like :meth:`update_one`, but insert a new document when nothing matches."""
        return self.collection.update_one(query, up, upsert=True)

    def update_many(self, query, up):
        """Apply update operators ``up`` to every document matching ``query``."""
        return self.collection.update_many(query, up)

    def delete_one(self, query):
        """Delete the first document matching ``query``."""
        return self.collection.delete_one(query)

    def upsert_id(self, _id, up):
        """Upsert the document with the given ``_id`` and return it as it is after the update."""
        query = {'_id': _id}
        return self.collection.find_one_and_update(query, up, upsert=True, return_document=ReturnDocument.AFTER)

    def bulk_write(self, requests):
        """Execute a list of PyMongo write operations (InsertOne, UpdateOne, ...) in one batch."""
        return self.collection.bulk_write(requests)

    def scan(self, query, limit=5000, left=-1, right=-1, index_field='_id', index_field_typ=int, total=-1,
             sort_keys=None, projection=None, id_typ=ObjectId, print_log=True):
        """Page through the documents matching ``query`` by ``index_field`` and return them all.

        Every page re-queries with ``index_field >= last seen value`` (``<=`` when
        scanning downward) and skips documents whose ``_id`` was already collected.
        ``index_field`` therefore need not be unique, but ``limit`` must be larger
        than the number of documents sharing one value (so at least 2). Otherwise
        the scan cannot advance and stops early.

        :param query: base filter; it is copied, so the caller's dict is not modified
        :param limit: page size
        :param left: inclusive lower bound on ``index_field`` (ignored when <= 0)
        :param right: inclusive upper bound on ``index_field`` (ignored when <= 0).
            With only ``right`` given, the scan runs downward.
        :param index_field: field used for paging
        :param index_field_typ: type passed to ``safe_get`` when reading ``index_field``
        :param total: stop once this many documents are collected (ignored when <= 0)
        :param sort_keys: sort used to pick the first document for a ``[left, right]``
            or unbounded scan; defaults to ascending ``index_field``
        :param projection: fields to return; must include ``index_field`` or the scan aborts
        :param id_typ: type passed to ``safe_get`` when reading ``_id`` for de-duplication
        :param print_log: print progress to stdout
        :return: list of documents, or None when nothing matches or the first
            document lacks ``index_field``
        """
        query = dict(query)  # never mutate the caller's filter
        if not sort_keys:
            # Start from the smallest index_field value. The old default sorted by
            # _id, which skipped documents whenever index_field != '_id'.
            sort_keys = [(index_field, 1)]

        bounded_range = right > left > 0
        if bounded_range:
            query[index_field] = {'$gte': left, '$lte': right}
        elif right > 0:
            query[index_field] = {'$lte': right}
            sort_keys = [(index_field, -1)]
        elif left > 0:
            query[index_field] = {'$gte': left}
            sort_keys = [(index_field, 1)]

        if print_log:
            print(query)
        one = self.find_one(query, sort_keys=sort_keys, projection=projection)
        if not one:
            return
        if index_field not in one:
            print('invalid index_field: {}'.format(index_field))
            return
        last_id = safe_get(one, index_field, index_field_typ, index_field_typ())

        # Page downward only when just an upper bound was given; otherwise upward.
        if right > 0 and not bounded_range:
            op, direction = '$lte', -1
        else:
            op, direction = '$gte', 1
        loop_sort = [(index_field, direction)]

        all_docs = []
        seen_ids = set()
        loop_cnt = 0
        while True:
            loop_cnt += 1
            # The [left, right] case drops $lte here; the `right < last_id` check below stops it.
            query[index_field] = {op: last_id}
            if print_log:
                print('loop_query: {}, sort_keys: {}'.format(query, loop_sort))
            docs = self.find(query, offset=0, limit=limit, sort_keys=loop_sort, projection=projection)

            new_cnt = 0
            finished = False
            for d in docs:
                last_id = safe_get(d, index_field, index_field_typ, index_field_typ())
                if 0 < right < last_id:
                    # Past the upper bound of an upward scan: nothing later can match.
                    finished = True
                    break
                # De-duplicate: the boundary document reappears at the top of each page.
                _id = safe_get(d, '_id', id_typ, id_typ())
                if _id in seen_ids:
                    continue
                seen_ids.add(_id)
                all_docs.append(d)
                new_cnt += 1
                if 0 < total <= len(all_docs):
                    # Stop immediately so at most `total` documents are returned.
                    finished = True
                    break

            if print_log:
                print('loop_cnt: {}, next: {}, examined: {}, total: {}'.format(loop_cnt, last_id, len(docs),
                                                                               len(all_docs)))
            if finished or len(docs) < limit:
                break
            if new_cnt == 0:
                # A full page of already-seen documents: last_id cannot advance,
                # and repeating the same query would loop forever.
                if print_log:
                    print('scan stopped: no progress past {} (limit {} too small?)'.format(last_id, limit))
                break

        return all_docs

    def full_scan(self, id_typ=int, projection=None):
        """Scan the whole collection by ``_id`` in pages of 10000 documents.

        :param id_typ: type of the ``_id`` values. It is used both for paging and for
            de-duplication, since both read the same ``_id`` field.
        :param projection: fields to return
        """
        return self.scan({}, limit=10000, index_field_typ=id_typ, id_typ=id_typ, projection=projection)

    def group(self, key, condition, initial, reducer):
        """Run the legacy MongoDB ``group`` command and return the groups as a list.

        NOTE: ``Collection.group`` was removed in PyMongo 4.0, and the ``group``
        command was removed in MongoDB 4.2. ``aggregate`` with ``$group`` is the
        replacement.

        Equivalent mongo shell call::

            db.getCollection("total_journal_issue").group({
                key: {journal_id: 1},
                reduce: function (curr, result) {  // curr: current document, result: accumulator
                    if (typeof curr.article_count != "undefined") {  // only when article_count exists
                        result.sum_article = curr.article_count + result.sum_article;
                    }
                },
                initial: {sum_article: 0}  // initial value of the accumulator field
            })

        which corresponds to the SQL::

            SELECT journal_id, SUM(article_count) AS sum_article
            FROM total_journal_issue
            GROUP BY journal_id
        """
        return list(self.collection.group(key, condition, initial, reducer))

    # Paged queries. Two common situations:
    # 1. A filter such as {ct: {$gte: xxx, $lt: yy}} matches a very large number of
    #    documents and ties up database resources: page on the database side
    #    (page_find_by_query).
    # 2. A filter on a long list of values of one field (mid_list, _id_list, ...) also
    #    ties up database resources: page over the value list instead
    #    (page_find_by_list). Typical for looking up users, comments and posts.

    def page_find_by_query(self, query, sort_keys=None, offset=0, page=200, file='', projection=None, print_log=True):
        """Fetch every document matching ``query`` using database-side skip/limit paging.

        :param query: filter document; anything other than a dict returns ``[]``
        :param sort_keys: sort specification; pass one to get a stable page order
        :param offset: number of matching documents to skip first
        :param page: page size; values <= 0 fetch everything in a single query
        :param file: absolute path of a CSV file to also write the results to;
            leave as '' to skip file output
        :param projection: fields to return
        :param print_log: print each page's range
        :return: list of documents
        """
        if not isinstance(query, dict):
            return list()

        csv_ = Csv(file, ['datas']) if file != '' else None
        ret_list = []
        skip = offset
        while True:
            results = self.find(query, skip, page, sort_keys, projection)
            if print_log:
                print('s: {}, e: {}'.format(skip, skip + page))
            ret_list.extend(results)
            if csv_ is not None:  # explicit check; Csv's truthiness is not known here
                csv_.extend(results)
            # A short page is the last one. With page <= 0 find() applied no limit,
            # so everything was already fetched; without this check the loop never ended.
            if page <= 0 or len(results) < page:
                break
            skip += page
        return ret_list

    def page_find_by_list(self, ids, field='_id', sort_keys=None, offset=0, page=100, page_db=200, file='',
                          projection=None, print_log=True):
        """Fetch the documents whose ``field`` is in ``ids``, ``page`` ids per ``$in`` query.

        :param ids: list of values to look up; anything other than a list returns ``[]``
        :param field: field matched against ``ids``; must be a str
        :param sort_keys: sort specification used within each id chunk
        :param offset: number of leading ids to skip (negative values count as 0)
        :param page: ids per ``$in`` query; values <= 0 return ``[]``
        :param page_db: database page size used for each chunk's query
        :param file: absolute path of a CSV file to also write the results to;
            leave as '' to skip file output
        :param projection: fields to return
        :param print_log: print each page's range
        :return: list of documents
        """
        if (not isinstance(ids, list)) or (not isinstance(field, str)):
            return list()

        csv_ = Csv(file, ['datas']) if file != '' else None
        ret_list = []
        if page <= 0:
            return ret_list

        for start in range(max(offset, 0), len(ids), page):
            query = {field: {'$in': ids[start:start + page]}}
            results = self.page_find_by_query(query, sort_keys, 0, page_db, '', projection, print_log=print_log)
            ret_list.extend(results)
            if csv_ is not None:  # explicit check; Csv's truthiness is not known here
                csv_.extend(results)
        return ret_list
|