scripts/lib/mongo.py

292 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pymongo
from lib.csv import Csv
from lib.util import *
from pymongo import ReturnDocument
class MongoDB:
    """Convenience wrapper around one pymongo collection.

    An instance owns one ``pymongo.MongoClient`` and exposes query, write and
    pagination helpers for the database/collection chosen in ``__init__``.
    """

    def __init__(self, host='127.0.0.1', port=27017, db='test', collection='test', user=None, password=None):
        """Store the database/collection settings and connect.

        :param host: database server host/IP
        :param port: database server port
        :param db: database name
        :param collection: collection (table) name
        :param user: user name; authentication is only used when both
            ``user`` and ``password`` are non-empty
        :param password: password
        """
        self.host = host
        self.port = port
        if user and password:
            # Database.authenticate() no longer exists in PyMongo 4, so the
            # credentials are handed to the client. authSource='admin' keeps the
            # original behaviour of authenticating against the admin database.
            self.client = pymongo.MongoClient(host=self.host, port=self.port, username=user,
                                              password=password, authSource='admin')
        else:
            self.client = pymongo.MongoClient(host=self.host, port=self.port)
        self.db = self.client[db]  # database
        self.collection = self.db[collection]  # collection
        # A missing database/collection only triggers a printed warning, not an error.
        if db not in self.client.list_database_names():
            print("数据库不存在!")
        if collection not in self.db.list_collection_names():
            print("表不存在!")

    def __str__(self):
        """Return host/port plus database name, collection name and document count."""
        num = self.collection.count_documents({})
        return "host: {}, port: {}\n数据库: {}, 表: {}, 共{}条数据".format(
            self.host, self.port, self.db.name, self.collection.name, num)

    def __len__(self):
        """Return the number of documents in the collection."""
        return self.collection.count_documents({})

    def close(self):
        """Close the underlying client."""
        self.client.close()

    def find(self, query, offset=-1, limit=0, sort_keys=None, projection=None):
        """Return the documents matching ``query`` as a list.

        :param query: filter document
        :param offset: number of documents to skip; applied only when > 0
        :param limit: maximum number of documents; applied only when > 0
        :param sort_keys: sort specification passed to ``Cursor.sort`` when not ``None``
        :param projection: projection passed to ``Collection.find``
        """
        cursor = self.collection.find(query, projection)
        if sort_keys is not None:
            cursor = cursor.sort(sort_keys)
        # Previously the offset was silently dropped whenever no limit was given.
        if offset > 0:
            cursor = cursor.skip(offset)
        if limit > 0:
            cursor = cursor.limit(limit)
        return list(cursor)

    def find_one(self, query, offset=0, sort_keys=None, projection=None):
        """Return the first matching document after skipping ``offset``, or ``None``."""
        items = self.find(query, offset, 1, sort_keys, projection)
        return items[0] if items else None

    def count(self, query):
        """Return the number of documents matching ``query`` (``None`` counts all).

        Uses ``count_documents`` because ``Cursor.count()`` was removed in
        PyMongo 4. Per the PyMongo docs, ``count_documents`` does not accept
        ``$where``/``$near``/``$nearSphere`` in the filter.
        """
        return int(self.collection.count_documents(query if query is not None else {}))

    def distinct(self, query, filed='_id'):
        """Return the distinct values of field ``filed`` among documents matching ``query``.

        The parameter name ``filed`` (sic) is kept for callers passing it by keyword.
        """
        return list(self.collection.find(query).distinct(filed))

    def aggregate(self, query, group, sort=None, limit=None):
        """Run ``$match`` -> ``$group`` [-> ``$sort``] [-> ``$limit``] and return the results.

        ``$sort``/``$limit`` stages are added only when ``sort``/``limit`` are truthy.
        """
        pipeline = [{'$match': query}, {'$group': group}]
        if sort:
            pipeline.append({'$sort': sort})
        if limit:
            pipeline.append({'$limit': limit})
        return list(self.collection.aggregate(pipeline))

    def insert_many(self, items):
        """Insert a list of documents."""
        return self.collection.insert_many(items)

    def insert_one(self, item):
        """Insert a single document."""
        return self.collection.insert_one(item)

    def _legacy_upsert(self, query, document):
        """Emulate the removed ``Collection.update(query, document, upsert=True)``.

        As in the old method, the first top-level key selects the mode: a
        ``$``-operator document updates the first match, anything else
        (including ``{}``) replaces it; an aggregation-pipeline list is passed
        to ``update_one``. Returns ``raw_result`` (the server's result dict,
        or ``None`` for unacknowledged writes), the shape ``update()`` returned.
        """
        if isinstance(document, list) or (document and next(iter(document)).startswith('$')):
            result = self.collection.update_one(query, document, upsert=True)
        else:
            result = self.collection.replace_one(query, document, upsert=True)
        return result.raw_result

    def set_on_insert(self, query, set_on_insert):
        """Upsert with ``$setOnInsert``: the fields are written only when a new document is inserted."""
        return self._legacy_upsert(query, {'$setOnInsert': set_on_insert})

    def upsert_query(self, query, up):
        """Upsert the first document matching ``query`` with an operator or replacement document ``up``."""
        return self._legacy_upsert(query, up)

    def update_one(self, query, up):
        """Update the first document matching ``query``."""
        return self.collection.update_one(query, up)

    def upsert_one(self, query, up):
        """Update the first document matching ``query``, inserting one if none matches."""
        return self.collection.update_one(query, up, upsert=True)

    def update_many(self, query, up):
        """Update every document matching ``query``."""
        return self.collection.update_many(query, up)

    def delete_one(self, query):
        """Delete the first document matching ``query``."""
        return self.collection.delete_one(query)

    def upsert_id(self, _id, up):
        """Upsert the document with ``_id`` and return it as it is after the update."""
        query = {'_id': _id}
        return self.collection.find_one_and_update(query, up, upsert=True, return_document=ReturnDocument.AFTER)

    def bulk_write(self, requests):
        """Execute a list of pymongo write operations in one call."""
        return self.collection.bulk_write(requests)

    def scan(self, query, limit=5000, left=-1, right=-1, index_field='_id', index_field_typ=int, total=-1,
             sort_keys=None, projection=None, id_typ=ObjectId, print_log=True):
        """Collect matching documents in batches, paginating on ``index_field``.

        Each batch re-queries with ``index_field`` >= (or <=) the last value
        seen instead of using a growing ``skip``; documents already collected
        are dropped by their ``_id``.

        Modes (a bound <= 0 means "not set"):
          * ``right > left > 0``: ascending over ``[left, right]``;
          * otherwise ``right > 0``: descending from ``right`` (``left`` ignored);
          * otherwise ``left > 0``: ascending from ``left``;
          * no bounds: ascending over every match.

        :param query: base filter; it is copied, never modified
        :param limit: batch size; <= 0 fetches everything in a single batch
        :param left: inclusive lower bound of ``index_field``
        :param right: inclusive upper bound of ``index_field``
        :param index_field: field to paginate on
        :param index_field_typ: type passed to ``safe_get`` for ``index_field`` values
            (presumably converts, defaulting to ``index_field_typ()`` — confirm in lib.util)
        :param total: stop once this many documents are collected (<= 0: no cap)
        :param sort_keys: sort used to pick the starting document in the bounded
            and unbounded ascending modes; defaults to ``index_field`` ascending
        :param projection: projection for every query; must keep ``index_field``
        :param id_typ: type passed to ``safe_get`` for ``_id`` when de-duplicating
        :param print_log: print the queries and per-batch progress
        :return: list of documents, or ``None`` when no starting document matches
            or the starting document lacks ``index_field``
        """
        # Copy so the range conditions added below don't leak into the caller's dict.
        query = dict(query) if query else {}
        if not sort_keys:
            # Start from the smallest index_field value. The old default of
            # '_id' picked an arbitrary start when index_field != '_id', and the
            # ">= start" batches then skipped every smaller value.
            sort_keys = [(index_field, 1)]
        bounded = right > left > 0
        descending = right > 0 and not bounded
        if bounded:
            query[index_field] = {'$gte': left, '$lte': right}
        elif right > 0:
            query[index_field] = {'$lte': right}
            sort_keys = [(index_field, -1)]
        elif left > 0:
            query[index_field] = {'$gte': left}
            sort_keys = [(index_field, 1)]
        if print_log:
            print(query)
        one = self.find_one(query, sort_keys=sort_keys, projection=projection)
        if not one:
            return None
        if index_field not in one:
            print('invalid index_field: {}'.format(index_field))
            return None
        last_id = safe_get(one, index_field, index_field_typ, index_field_typ())
        # Every batch walks in the same direction: down for a right-only scan, up otherwise.
        op, direction = ('$lte', -1) if descending else ('$gte', 1)
        batch_sort = [(index_field, direction)]
        all_docs = []
        seen_ids = set()
        loop_cnt = 0
        more = True
        while more:
            loop_cnt += 1
            query[index_field] = {op: last_id}
            if print_log:
                print('loop_query: {}, sort_keys: {}'.format(query, batch_sort))
            docs = self.find(query, offset=0, limit=limit, sort_keys=batch_sort, projection=projection)
            added = 0
            for d in docs:
                last_id = safe_get(d, index_field, index_field_typ, index_field_typ())
                if 0 < right < last_id:
                    # Beyond the upper bound: finish after this batch.
                    more = False
                    continue
                _id = safe_get(d, '_id', id_typ, id_typ())
                if _id in seen_ids:  # the boundary document(s) repeat across batches
                    continue
                all_docs.append(d)
                seen_ids.add(_id)
                added += 1
                if 0 < total <= len(all_docs):
                    # Stop immediately so the result never exceeds `total`.
                    more = False
                    break
            if print_log:
                print('loop_cnt: {}, next: {}, examined: {}, total: {}'.format(loop_cnt, last_id, len(docs),
                                                                              len(all_docs)))
            if limit <= 0 or len(docs) < limit:
                # Short (or unlimited) batch: nothing is left to fetch.
                more = False
            elif more and not added:
                # A full batch with nothing new means last_id cannot advance
                # (more than `limit` documents share one index_field value);
                # re-querying would loop forever.
                print('scan stalled at {} = {}, stopping'.format(index_field, last_id))
                more = False
        return all_docs

    def full_scan(self, id_typ=int, projection=None):
        """Scan the whole collection by ``_id`` in batches of 10000 (``id_typ`` is the ``_id`` value type)."""
        return self.scan({}, limit=10000, index_field_typ=id_typ, projection=projection)

    def group(self, key, condition, initial, reducer):
        """Run the legacy ``group`` command and return its result as a list.

        Example (mongo shell) and its SQL equivalent::

            db.getCollection("total_journal_issue").group({
                key: {journal_id: 1},
                reduce: function(curr, result) {  // curr: current document, result: result document
                    // only add when article_count exists (see JS docs for other ways to test this)
                    if (typeof curr.article_count != "undefined") {
                        result.sum_article = curr.article_count + result.sum_article;
                    }
                },
                initial: {sum_article: 0}  // initial value of the sum_article field
            })

            SELECT journal_id, SUM(article_count) AS sum_article
            FROM total_journal_issue
            GROUP BY journal_id

        NOTE(review): ``Collection.group`` was removed in PyMongo 4 and the
        ``group`` command in MongoDB 4.2; use :meth:`aggregate` with a
        ``$group`` stage on newer stacks.
        """
        return list(self.collection.group(key, condition, initial, reducer))

    # Paginated queries. Two common cases:
    # 1. A filter such as {ct: {$gte: xxx, $lt: yy}} that matches very many
    #    documents is expensive to fetch in one go; page it on the database
    #    side with page_find_by_query.
    # 2. A list of values of some field (mid_list, _id_list, ...): a very long
    #    $in list is also expensive; page over the list with page_find_by_list.
    #    Typically used to look up users, comments and posts.

    def page_find_by_query(self, query, sort_keys=None, offset=0, page=200, file='', projection=None, print_log=True):
        """Fetch every document matching ``query`` page by page (case 1 above).

        :param query: filter dict; anything that is not a dict yields ``[]``
        :param sort_keys: sort applied to every page
        :param offset: documents to skip first (negative values are treated as 0)
        :param page: page size; <= 0 fetches everything in a single query
        :param file: absolute path handed to ``lib.csv.Csv`` to also write the
            results to a file; leave empty to skip file output
        :param projection: projection for every page
        :param print_log: print the skip range of each page
        :return: list of all fetched documents
        """
        if not isinstance(query, dict):
            return list()
        ret_list = []
        skip = max(offset, 0)
        limit = page
        csv_ = Csv(file, ['datas']) if file else None
        while True:
            results = self.find(query, skip, limit, sort_keys, projection)
            if print_log:
                print('s: {}, e: {}'.format(skip, skip + limit))
            ret_list.extend(results)
            if csv_ is not None:
                csv_.extend(results)
            # page <= 0 means find() applied no limit, so one round got everything;
            # previously this case looped forever re-reading the same documents.
            if limit <= 0 or len(results) < limit:
                break
            skip += limit
        return ret_list

    def page_find_by_list(self, ids, field='_id', sort_keys=None, offset=0, page=100, page_db=200, file='',
                          projection=None, print_log=True):
        """Fetch the documents whose ``field`` is in ``ids``, ``page`` ids at a time (case 2 above).

        Each chunk becomes a ``{field: {'$in': chunk}}`` query that is itself
        paged with :meth:`page_find_by_query` using ``page_db`` as page size, so
        ``sort_keys`` orders results within a chunk only.

        :param ids: list of values to look up; a non-list yields ``[]``
        :param field: field name matched against ``ids``; a non-str yields ``[]``
        :param offset: number of ids to skip first (negative values are treated as 0)
        :param page: ids per chunk; <= 0 fetches nothing
        :param page_db: database page size for each chunk query
        :param file: absolute path handed to ``lib.csv.Csv`` to also write the
            results to a file; leave empty to skip file output
        :return: list of all fetched documents
        """
        if (not isinstance(ids, list)) or (not isinstance(field, str)):
            return list()
        ret_list = []
        csv_ = Csv(file, ['datas']) if file else None
        if page <= 0:
            return ret_list
        for start in range(max(offset, 0), len(ids), page):
            query = {field: {'$in': ids[start:start + page]}}
            results = self.page_find_by_query(query, sort_keys, 0, page_db, '', projection, print_log=print_log)
            ret_list.extend(results)
            if csv_ is not None:
                csv_.extend(results)
        return ret_list