import os
import json
from collections import defaultdict, OrderedDict
class BaseDBConnection(object):
    """ Base KG connection for database

    This class only defines the interface shared by the concrete backends.
    Apart from construction and finalization, every method raises
    ``NotImplementedError`` and must be overridden by a subclass.
    """

    def __init__(self, db_path, chunksize):
        """ Create a connection to database

        :param db_path: database path (not used by the base class itself)
        :type db_path: str
        :param chunksize: the chunksize to load/write database
        :type chunksize: int
        """
        self._conn = None
        self.chunksize = chunksize

    def close(self):
        """ Close the connection safely
        """
        raise NotImplementedError

    def __del__(self):
        # Finalizers must not raise: a backend that has not implemented
        # `close` (including this abstract base) simply has nothing to release.
        try:
            self.close()
        except NotImplementedError:
            pass

    def create_table(self, table_name, columns, column_types):
        """ Create a table with given columns and types

        :param table_name: the table name to create
        :type table_name: str
        :param columns: the columns to create
        :type columns: List[str]
        :param column_types: the corresponding column types
        :type column_types: List[str]
        """
        raise NotImplementedError

    def get_columns(self, table_name, columns):
        """ Get column information from a table

        :param table_name: the table name to retrieve
        :type table_name: str
        :param columns: the columns to retrieve
        :type columns: List[str]
        :return: a list of retrieved rows
        :rtype: List[Dict[str, object]]
        """
        raise NotImplementedError

    def select_row(self, table_name, _id, columns):
        """ Select a row from a table

        :param table_name: the table name to retrieve
        :type table_name: str
        :param _id: the row id
        :type _id: str
        :param columns: the columns to retrieve
        :type columns: List[str]
        :return: a retrieved row
        :rtype: Dict[str, object]
        """
        raise NotImplementedError

    def select_rows(self, table_name, _ids, columns):
        """ Select rows from a table

        :param table_name: the table name to retrieve
        :type table_name: str
        :param _ids: the row ids
        :type _ids: List[str]
        :param columns: the columns to retrieve
        :type columns: List[str]
        :return: retrieved rows
        :rtype: List[Dict[str, object]]
        """
        raise NotImplementedError

    def insert_row(self, table_name, row):
        """ Insert a row into a table

        :param table_name: the table name to insert
        :type table_name: str
        :param row: the row to insert
        :type row: Dict[str, object]
        """
        raise NotImplementedError

    def insert_rows(self, table_name, rows):
        """ Insert several rows into a table

        :param table_name: the table name to insert
        :type table_name: str
        :param rows: the rows to insert
        :type rows: List[Dict[str, object]]
        """
        raise NotImplementedError

    def get_update_op(self, update_columns, operator):
        """ Get an update operator based on columns and an operator

        :param update_columns: a list of columns to update
        :type update_columns: List[str]
        :param operator: an operator that applies to the columns, including "+", "-", "*", "/", "="
        :type operator: str
        :return: an operator that suits the backend database
        :rtype: object
        """
        raise NotImplementedError

    def update_row(self, table_name, row, update_op, update_columns):
        """ Update a row that exists in a table

        :param table_name: the table name to update
        :type table_name: str
        :param row: a new row
        :type row: Dict[str, object]
        :param update_op: an operator that is returned by `get_update_op`
        :type update_op: object
        :param update_columns: the columns to update
        :type update_columns: List[str]
        """
        raise NotImplementedError

    def update_rows(self, table_name, rows, update_ops, update_columns):
        """ Update rows that exist in a table

        :param table_name: the table name to update
        :type table_name: str
        :param rows: new rows
        :type rows: List[Dict[str, object]]
        :param update_ops: operator(s) that are returned by `get_update_op`
        :type update_ops: Union[List[object], object]
        :param update_columns: the columns to update
        :type update_columns: List[str]
        """
        raise NotImplementedError

    def get_rows_by_keys(self, table_name, bys, keys, columns, order_bys=None, reverse=False, top_n=None):
        """ Retrieve rows by specific keys in some order

        :param table_name: the table name to retrieve
        :type table_name: str
        :param bys: the given columns to match
        :type bys: List[str]
        :param keys: the given values to match
        :type keys: List[str]
        :param columns: the given columns to retrieve
        :type columns: List[str]
        :param order_bys: the columns whose values are used to sort rows
        :type order_bys: List[str]
        :param reverse: whether to sort in a reversed order
        :type reverse: bool
        :param top_n: how many rows to return, default `None` for all rows
        :type top_n: int
        :return: retrieved rows
        :rtype: List[Dict[str, object]]
        """
        raise NotImplementedError
class SqliteDBConnection(BaseDBConnection):
    """ KG connection for SQLite database

    Table and column names are interpolated into the SQL text because
    identifiers cannot be bound as parameters; they must therefore come from
    trusted code. Row ids are always bound as parameters.
    """

    # SQLite builds older than 3.32.0 accept at most 999 bound parameters per
    # statement by default, so id lists are never bound in larger batches.
    _MAX_BOUND_IDS = 999

    def __init__(self, db_path, chunksize):
        """ Create a connection to SQLite database

        :param db_path: database path, e.g., /home/xliucr/ASER/KG.db
        :type db_path: str
        :param chunksize: the chunksize to load/write database
        :type chunksize: int
        """
        import sqlite3

        super(SqliteDBConnection, self).__init__(db_path, chunksize)
        self._conn = sqlite3.connect(db_path)

    def close(self):
        """ Close the connection safely
        """
        # `_conn` is missing when __init__ failed before the base initializer
        # ran; the finalizer still calls `close` in that case.
        conn = getattr(self, "_conn", None)
        if conn:
            conn.close()

    def _id_chunks(self, _ids):
        """ Yield successive slices of `_ids` that are safe to bind in one statement

        :param _ids: the row ids
        :type _ids: List[str]
        :return: an iterator over lists of at most min(chunksize, 999) ids
        :rtype: Iterator[List[str]]
        """
        step = min(self.chunksize, self._MAX_BOUND_IDS)
        for start in range(0, len(_ids), step):
            yield list(_ids[start:start + step])

    @staticmethod
    def _sql_literal(value):
        """ Render a Python value as an SQL literal for an inlined SET clause

        :param value: the value to render
        :type value: object
        :return: the SQL text of the value
        :rtype: str
        """
        if value is None:
            return "NULL"
        if isinstance(value, str):
            # single quotes are escaped by doubling them
            return "'" + value.replace("'", "''") + "'"
        if isinstance(value, bool):
            return "1" if value else "0"
        # Parentheses keep a negative number from fusing with a preceding
        # "-" operator into "--", which would start an SQL comment.
        return "(" + str(value) + ")"

    def create_table(self, table_name, columns, column_types):
        """ Create a table with given columns and types

        :param table_name: the table name to create
        :type table_name: str
        :param columns: the columns to create
        :type columns: List[str]
        :param column_types: the corresponding column types, please refer to https://www.sqlite.org/datatype3.html
        :type column_types: List[str]
        """
        create_table = "CREATE TABLE %s (%s);" % (
            table_name, ",".join([' '.join(x) for x in zip(columns, column_types)])
        )
        self._conn.execute(create_table)
        self._conn.commit()

    def get_columns(self, table_name, columns):
        """ Get column information from a table

        :param table_name: the table name to retrieve
        :type table_name: str
        :param columns: the columns to retrieve
        :type columns: List[str]
        :return: a list of retrieved rows
        :rtype: List[Dict[str, object]]
        """
        select_table = "SELECT %s FROM %s;" % (",".join(columns), table_name)
        return [OrderedDict(zip(columns, x)) for x in self._conn.execute(select_table)]

    def select_row(self, table_name, _id, columns):
        """ Select a row from a table
        (suggestion: consider to use `select_rows` if you want to retrieve multiple rows)

        :param table_name: the table name to retrieve
        :type table_name: str
        :param _id: the row id
        :type _id: str
        :param columns: the columns to retrieve
        :type columns: List[str]
        :return: a retrieved row, or `None` if no row has the given id
        :rtype: Dict[str, object]
        """
        select_table = "SELECT %s FROM %s WHERE _id=?;" % (",".join(columns), table_name)
        result = self._conn.execute(select_table, [_id]).fetchone()
        if result is None:
            return None
        return OrderedDict(zip(columns, result))

    def select_rows(self, table_name, _ids, columns):
        """ Select rows from a table

        :param table_name: the table name to retrieve
        :type table_name: str
        :param _ids: the row ids
        :type _ids: List[str]
        :param columns: the columns to retrieve; must contain "_id" because the
            retrieved rows are matched back to `_ids` through that column
        :type columns: List[str]
        :return: retrieved rows in the order of `_ids`, with `None` for missing ids
        :rtype: List[Dict[str, object]]
        """
        if len(_ids) == 0:
            return []
        row_cache = dict()
        for chunk in self._id_chunks(_ids):
            # ids are bound as parameters so that quotes inside an id cannot
            # break (or inject into) the statement
            select_table = "SELECT %s FROM %s WHERE _id IN (%s);" % (
                ",".join(columns), table_name, ",".join(["?"] * len(chunk))
            )
            for x in self._conn.execute(select_table, chunk):
                exact_match_row = OrderedDict(zip(columns, x))
                row_cache[exact_match_row["_id"]] = exact_match_row
        return [row_cache.get(_id, None) for _id in _ids]

    def insert_row(self, table_name, row):
        """ Insert a row into a table
        (suggestion: consider to use `insert_rows` if you want to insert multiple rows)

        :param table_name: the table name to insert
        :type table_name: str
        :param row: the row to insert; values are written in the dict's iteration order
        :type row: Dict[str, object]
        """
        insert_table = "INSERT INTO %s VALUES (%s)" % (table_name, ",".join(['?'] * (len(row))))
        self._conn.execute(insert_table, list(row.values()))
        self._conn.commit()

    def insert_rows(self, table_name, rows):
        """ Insert several rows into a table

        :param table_name: the table name to insert
        :type table_name: str
        :param rows: the rows to insert; the placeholder count is taken from the first row
        :type rows: List[Dict[str, object]]
        """
        if len(rows) > 0:
            insert_table = "INSERT INTO %s VALUES (%s)" % (table_name, ",".join(['?'] * (len(next(iter(rows))))))
            self._conn.executemany(insert_table, [list(row.values()) for row in rows])
            self._conn.commit()

    def get_update_op(self, update_columns, operator):
        """ Get an update operator based on columns and an operator

        :param update_columns: a list of columns to update
        :type update_columns: List[str]
        :param operator: an operator that applies to the columns, including "+", "-", "*", "/", "="
        :type operator: str
        :return: a SET clause template with one "?" per column
        :rtype: str
        :raises NotImplementedError: if the operator is not one of the five above
        """
        # A tuple is used on purpose: `operator in "+-*/"` is a substring test
        # and would also accept "" or "+-".
        if operator in ("+", "-", "*", "/"):
            return ",".join([c + "=" + c + operator + "?" for c in update_columns])
        elif operator == "=":
            return ",".join([c + "=?" for c in update_columns])
        else:
            raise NotImplementedError

    def _update_update_op(self, row, update_op, update_columns):
        """ Fill the placeholders of an update operator with the values of a row

        :param row: a new row
        :type row: Dict[str, object]
        :param update_op: an operator that is returned by `get_update_op`
        :type update_op: str
        :param update_columns: the columns to update
        :type update_columns: List[str]
        :return: the SET clause with every "?" replaced by an SQL literal
        :rtype: str
        """
        update_op_sp = update_op.split('?')
        # drop the empty fragment(s) that follow the last placeholder
        while len(update_op_sp) > 0 and update_op_sp[-1] == '':
            update_op_sp.pop()
        assert len(update_op_sp) == len(update_columns)
        new_update_op = []
        for fragment, update_column in zip(update_op_sp, update_columns):
            new_update_op.append(fragment)
            new_update_op.append(self._sql_literal(row[update_column]))
        return ''.join(new_update_op)

    def update_row(self, table_name, row, update_op, update_columns):
        """ Update a row that exists in a table
        (suggestion: consider to use `update_rows` if you want to update multiple rows)

        :param table_name: the table name to update
        :type table_name: str
        :param row: a new row
        :type row: Dict[str, object]
        :param update_op: an operator that is returned by `get_update_op`
        :type update_op: str
        :param update_columns: the columns to update
        :type update_columns: List[str]
        """
        update_table = "UPDATE %s SET %s WHERE _id=?" % (table_name, update_op)
        self._conn.execute(update_table, [row[k] for k in update_columns] + [row["_id"]])
        self._conn.commit()

    def update_rows(self, table_name, rows, update_ops, update_columns):
        """ Update rows that exist in a table

        :param table_name: the table name to update
        :type table_name: str
        :param rows: new rows
        :type rows: List[Dict[str, object]]
        :param update_ops: operator(s) that are returned by `get_update_op`;
            either one operator per row or a single operator shared by all rows
        :type update_ops: Union[List[str], str]
        :param update_columns: the columns to update
        :type update_columns: List[str]
        """
        if len(rows) > 0:
            if isinstance(update_ops, (tuple, list)):
                assert len(rows) == len(update_ops)
                row_ops = update_ops
            else:
                row_ops = [update_ops] * len(rows)
            # Rows whose filled-in SET clause is identical are updated by one
            # statement, which is much faster than one statement per row.
            update_op_collections = defaultdict(list)
            for row, update_op in zip(rows, row_ops):
                new_update_op = self._update_update_op(row, update_op, update_columns)
                update_op_collections[new_update_op].append(row["_id"])
            for new_update_op, _ids in update_op_collections.items():
                for chunk in self._id_chunks(_ids):
                    update_table = "UPDATE %s SET %s WHERE _id IN (%s);" % (
                        table_name, new_update_op, ",".join(["?"] * len(chunk))
                    )
                    self._conn.execute(update_table, chunk)
            self._conn.commit()

    def get_rows_by_keys(self, table_name, bys, keys, columns, order_bys=None, reverse=False, top_n=None):
        """ Retrieve rows by specific keys in some order

        :param table_name: the table name to retrieve
        :type table_name: str
        :param bys: the given columns to match; no filter is applied if empty
        :type bys: List[str]
        :param keys: the given values to match
        :type keys: List[str]
        :param columns: the given columns to retrieve
        :type columns: List[str]
        :param order_bys: the columns whose values are used to sort rows
        :type order_bys: List[str]
        :param reverse: whether to sort in a reversed order (applies to every column in `order_bys`)
        :type reverse: bool
        :param top_n: how many rows to return, default `None` for all rows
        :type top_n: int
        :return: retrieved rows
        :rtype: List[Dict[str, object]]
        """
        select_table = "SELECT %s FROM %s" % (",".join(columns), table_name)
        if bys:
            select_table += " WHERE %s" % (" AND ".join(["%s=?" % (by) for by in bys]))
        if order_bys:
            # The direction is repeated per column: in SQL a single trailing
            # DESC would only apply to the last column.
            direction = "DESC" if reverse else "ASC"
            select_table += " ORDER BY %s" % (",".join(["%s %s" % (k, direction) for k in order_bys]))
        if top_n:
            select_table += " LIMIT %d" % (top_n)
        select_table += ";"
        parameters = list(keys) if keys else []
        return [OrderedDict(zip(columns, x)) for x in self._conn.execute(select_table, parameters)]
class MongoDBConnection(BaseDBConnection):
    """ KG connection for MongoDB
    """

    # operator -> (MongoDB update operator, placeholder value); the placeholder
    # lets `_update_update_op` tell "+" from "-" and "*" from "/"
    _PLACEHOLDERS = {
        "+": ("$inc", 1),
        "-": ("$inc", -1),
        "*": ("$mul", 2),
        "/": ("$mul", 0.5),
        "=": ("$set", 1),
    }

    def __init__(self, db_path, chunksize):
        """ Create a connection to MongoDB

        :param db_path: database path, e.g., mongodb://localhost:27017/ASER
        :type db_path: str
        :param chunksize: the chunksize to load/write database
        :type chunksize: int
        """
        import pymongo

        super(MongoDBConnection, self).__init__(db_path, chunksize)
        # the last path component is the database name, the rest is the server URI
        host_port, db_name = os.path.split(db_path)
        self._client = pymongo.MongoClient(host_port, document_class=OrderedDict)
        self._conn = self._client[db_name]

    def close(self):
        """ Close the connection safely
        """
        # `_client` is missing when __init__ failed before the client was
        # created; the finalizer still calls `close` in that case.
        client = getattr(self, "_client", None)
        if client is not None:
            client.close()

    def create_table(self, table_name, columns=None, column_types=None):
        """ Create a table without the necessity to provide column information

        :param table_name: the table name to create
        :type table_name: str
        :param columns: ignored; accepted for compatibility with the base interface
        :type columns: List[str]
        :param column_types: ignored; accepted for compatibility with the base interface
        :type column_types: List[str]
        """
        self._conn[table_name]

    def __get_projection(self, columns):
        """ Build a projection that returns exactly the given columns

        :param columns: the columns to retrieve
        :type columns: List[str]
        :return: a projection in which "_id" is excluded unless it is listed in `columns`
        :rtype: Dict[str, int]
        """
        projection = {"_id": 0}
        for k in columns:
            projection[k] = 1
        return projection

    def get_columns(self, table_name, columns):
        """ Get column information from a table

        :param table_name: the table name to retrieve
        :type table_name: str
        :param columns: the columns to retrieve
        :type columns: List[str]
        :return: a list of retrieved rows
        :rtype: List[Dict[str, object]]
        """
        projection = self.__get_projection(columns)
        return list(self._conn[table_name].find({}, projection))

    def select_row(self, table_name, _id, columns):
        """ Select a row from a table
        (suggestion: consider to use `select_rows` if you want to retrieve multiple rows)

        :param table_name: the table name to retrieve
        :type table_name: str
        :param _id: the row id
        :type _id: str
        :param columns: the columns to retrieve
        :type columns: List[str]
        :return: a retrieved row
        :rtype: Dict[str, object]
        """
        projection = self.__get_projection(columns)
        return self._conn[table_name].find_one({"_id": _id}, projection)

    def select_rows(self, table_name, _ids, columns):
        """ Select rows from a table

        :param table_name: the table name to retrieve
        :type table_name: str
        :param _ids: the row ids
        :type _ids: List[str]
        :param columns: the columns to retrieve; must contain "_id" because the
            retrieved rows are matched back to `_ids` through that field
        :type columns: List[str]
        :return: retrieved rows in the order of `_ids`, with `None` for missing ids
        :rtype: List[Dict[str, object]]
        """
        table = self._conn[table_name]
        projection = self.__get_projection(columns)
        row_cache = {}
        for idx in range(0, len(_ids), self.chunksize):
            query = {"_id": {'$in': _ids[idx:idx + self.chunksize]}}
            for x in table.find(query, projection):
                row_cache[x["_id"]] = x
        return [row_cache.get(_id, None) for _id in _ids]

    def insert_row(self, table_name, row):
        """ Insert a row into a table
        (suggestion: consider to use `insert_rows` if you want to insert multiple rows)

        :param table_name: the table name to insert
        :type table_name: str
        :param row: the row to insert
        :type row: Dict[str, object]
        """
        self._conn[table_name].insert_one(row)

    def insert_rows(self, table_name, rows):
        """ Insert several rows into a table

        :param table_name: the table name to insert
        :type table_name: str
        :param rows: the rows to insert; nothing is done for an empty list
        :type rows: List[Dict[str, object]]
        """
        # guard the empty case like the SQLite backend does, instead of
        # handing an empty batch to the driver
        if len(rows) > 0:
            self._conn[table_name].insert_many(rows)

    def get_update_op(self, update_columns, operator):
        """ Get an update operator based on columns and an operator

        :param update_columns: a list of columns to update
        :type update_columns: List[str]
        :param operator: an operator that applies to the columns, including "+", "-", "*", "/", "="
        :type operator: str
        :return: an operator template whose values are placeholders
        :rtype: Dict[str, Dict[str, float]]
        :raises NotImplementedError: if the operator is not one of the five above
        """
        try:
            mongo_op, placeholder = self._PLACEHOLDERS[operator]
        except (KeyError, TypeError):
            raise NotImplementedError
        update_ops = {}
        for update_column in update_columns:
            update_ops[update_column] = placeholder
        return {mongo_op: update_ops}

    def _update_update_op(self, row, update_op, update_columns):
        """ Update the operator for a single row

        :param row: a new row
        :type row: Dict[str, object]
        :param update_op: an operator that is returned by `get_update_op`; it is not modified
        :type update_op: Dict[str, Dict[str, float]]
        :param update_columns: the columns to update
        :type update_columns: List[str]
        :return: a new operator whose placeholders are replaced by the row's values
        :rtype: Dict[str, Dict[str, float]]
        """
        # Copy the inner dicts as well: a shallow copy would share them with
        # the template, and overwriting its placeholders would make the next
        # call misread "+" as "-" (or "*" as "/").
        new_update_op = {k: dict(v) for k, v in update_op.items()}
        for k, v in new_update_op.items():
            if k == "$inc":
                for update_column in update_columns:
                    if v[update_column] == 1:  # "+"
                        v[update_column] = row[update_column]
                    else:  # "-" is an increment by the negated value
                        v[update_column] = -row[update_column]
            elif k == "$mul":
                for update_column in update_columns:
                    if v[update_column] == 2:  # "*"
                        v[update_column] = row[update_column]
                    else:  # "/" is a multiplication by the reciprocal
                        v[update_column] = 1.0 / row[update_column]
            elif k == "$set":
                for update_column in update_columns:
                    v[update_column] = row[update_column]
        return new_update_op

    def update_row(self, table_name, row, update_op, update_columns):
        """ Update a row that exists in a table
        (suggestion: consider to use `update_rows` if you want to update multiple rows)

        :param table_name: the table name to update
        :type table_name: str
        :param row: a new row
        :type row: Dict[str, object]
        :param update_op: an operator that is returned by `get_update_op`
        :type update_op: Dict[str, Dict[str, float]]
        :param update_columns: the columns to update
        :type update_columns: List[str]
        """
        self._conn[table_name].update_one({"_id": row["_id"]}, self._update_update_op(row, update_op, update_columns))

    def update_rows(self, table_name, rows, update_ops, update_columns):
        """ Update rows that exist in a table

        :param table_name: the table name to update
        :type table_name: str
        :param rows: new rows
        :type rows: List[Dict[str, object]]
        :param update_ops: operator(s) that are returned by `get_update_op`;
            either one operator per row or a single operator shared by all rows
        :type update_ops: Union[List[Dict[str, Dict[str, float]]], Dict[str, Dict[str, float]]]
        :param update_columns: the columns to update
        :type update_columns: List[str]
        """
        if len(rows) > 0:
            if isinstance(update_ops, (tuple, list)):
                assert len(rows) == len(update_ops)
                row_ops = update_ops
            else:
                row_ops = [update_ops] * len(rows)
            # Rows whose filled-in operator is identical are updated together;
            # the JSON text of the operator serves as a hashable grouping key.
            update_op_collections = OrderedDict()
            for row, update_op in zip(rows, row_ops):
                new_update_op = self._update_update_op(row, update_op, update_columns)
                group_key = json.dumps(new_update_op)
                if group_key not in update_op_collections:
                    update_op_collections[group_key] = (new_update_op, [])
                update_op_collections[group_key][1].append(row["_id"])
            table = self._conn[table_name]
            for new_update_op, _ids in update_op_collections.values():
                for idx in range(0, len(_ids), self.chunksize):
                    query = {"_id": {'$in': _ids[idx:idx + self.chunksize]}}
                    table.update_many(query, new_update_op)

    def get_rows_by_keys(self, table_name, bys, keys, columns, order_bys=None, reverse=False, top_n=None):
        """ Retrieve rows by specific keys in some order

        :param table_name: the table name to retrieve
        :type table_name: str
        :param bys: the given columns to match
        :type bys: List[str]
        :param keys: the given values to match
        :type keys: List[str]
        :param columns: the given columns to retrieve
        :type columns: List[str]
        :param order_bys: the columns whose values are used to sort rows
        :type order_bys: List[str]
        :param reverse: whether to sort in a reversed order
        :type reverse: bool
        :param top_n: how many rows to return, default `None` for all rows
        :type top_n: int
        :return: retrieved rows
        :rtype: List[Dict[str, object]]
        """
        query = OrderedDict(zip(bys, keys))
        projection = self.__get_projection(columns)
        cursor = self._conn[table_name].find(query, projection)
        if order_bys:
            direction = -1 if reverse else 1
            cursor = cursor.sort([(k, direction) for k in order_bys])
        if top_n:
            # let the server stop after `top_n` documents
            cursor = cursor.limit(top_n)
        return list(cursor)