import heapq
import operator
import random
from collections import OrderedDict
from ..concept import ASERConcept, ASERConceptInstancePair
from ..eventuality import Eventuality
from ..relation import Relation, relation_senses
from ..database.db_connection import SqliteDBConnection, MongoDBConnection
from ..database.utils import compute_overlap
# Default chunk size used to load/write the database.
CHUNKSIZE = 32768

# Eventuality table: name, column names, and column types (parallel lists).
EVENTUALITY_TABLE_NAME = "Eventualities"
EVENTUALITY_COLUMNS = [
    "_id",
    "frequency",
    "pattern",
    "verbs",
    "skeleton_words",
    "words",
    "info",
]
EVENTUALITY_COLUMN_TYPES = [
    "PRIMARY KEY",
    "REAL",
    "TEXT",
    "TEXT",
    "TEXT",
    "TEXT",
    "BLOB",
]

# Concept table: name, column names, and column types (parallel lists).
CONCEPT_TABLE_NAME = "Concepts"
CONCEPT_COLUMNS = [
    "_id",
    "pattern",
    "info",
]
CONCEPT_COLUMN_TYPES = [
    "PRIMARY KEY",
    "TEXT",
    "BLOB",
]

# Relation table: the fixed id columns are followed by one REAL column per relation sense.
RELATION_TABLE_NAME = "Relations"
RELATION_COLUMNS = ["_id", "hid", "tid", *relation_senses]
RELATION_COLUMN_TYPES = ["PRIMARY KEY", "TEXT", "TEXT", *("REAL" for _ in relation_senses)]

# Concept-instance pair table: name, column names, and column types (parallel lists).
CONCEPTINSTANCEPAIR_TABLE_NAME = "ConceptInstancePairs"
CONCEPTINSTANCEPAIR_COLUMNS = [
    "_id",
    "cid",
    "eid",
    "pattern",
    "score",
]
CONCEPTINSTANCEPAIR_COLUMN_TYPES = [
    "PRIMARY KEY",
    "TEXT",
    "TEXT",
    "TEXT",
    "REAL",
]
[docs]class ASERKGConnection(object):
""" KG connection for ASER (including eventualities and relations)
"""
def __init__(self, db_path, db="sqlite", mode="cache", grain=None, chunksize=CHUNKSIZE):
    """ Open the backend connection, validate the options, and prepare id sets and caches

    :param db_path: database path
    :type db_path: str
    :param db: the backend database, e.g., "sqlite" or "mongodb"
    :type db: str (default = "sqlite")
    :param mode: the mode to use the connection.
        "insert": this connection is only used to insert/update rows;
        "cache": this connection caches some contents that have been retrieved;
        "memory": this connection loads all contents in memory;
    :type mode: str (default = "cache")
    :param grain: the grain to build cache
        "words": cache is built on "verbs", "skeleton_words", and "words"
        "skeleton_words": cache is built on "verbs", and "skeleton_words"
        "verbs": cache is built on "verbs"
        None: no cache
    :type grain: Union[str, None] (default = None)
    :param chunksize: the chunksize to load/write database
    :type chunksize: int (default = 32768)
    :raises ValueError: if `db`, `mode`, or `grain` is not a supported value
    """
    # The backend is opened before mode/grain are validated.
    if db == "sqlite":
        self._conn = SqliteDBConnection(db_path, chunksize)
    elif db == "mongodb":
        self._conn = MongoDBConnection(db_path, chunksize)
    else:
        raise ValueError("Error: %s database is not supported!" % (db))

    self.mode = mode
    if self.mode not in ["insert", "cache", "memory"]:
        raise ValueError("only support insert/cache/memory modes.")
    if grain not in [None, "verbs", "skeleton_words", "words"]:
        raise ValueError("Error: only support None/verbs/skeleton_words/words grain.")
    self.grain = grain

    # Table layouts come from the module-level constants.
    self.eventuality_table_name = EVENTUALITY_TABLE_NAME
    self.eventuality_columns = EVENTUALITY_COLUMNS
    self.eventuality_column_types = EVENTUALITY_COLUMN_TYPES
    self.relation_table_name = RELATION_TABLE_NAME
    self.relation_columns = RELATION_COLUMNS
    self.relation_column_types = RELATION_COLUMN_TYPES

    # Known ids and id -> object caches.
    self.eids, self.rids = set(), set()
    self.eid2eventuality_cache, self.rid2relation_cache = dict(), dict()

    # Each grain also covers every coarser grain listed before it.
    grain_levels = ["verbs", "skeleton_words", "words"]
    if self.grain in grain_levels:
        depth = grain_levels.index(self.grain) + 1
        self.partial2eids_cache = {level: dict() for level in grain_levels[:depth]}
    else:
        self.partial2eids_cache = dict()
    self.partial2rids_cache = {"hid": dict()}

    self.init()
def init(self):
    """ Initialize the ASERKGConnection, including creating tables, loading eids and rids, and building cache

    In "memory" mode every eventuality and relation row is converted and cached;
    in the other modes only the ids are loaded.

    :raises ValueError: if a table has no columns or no column types defined
    """
    table_specs = zip(
        [self.eventuality_table_name, self.relation_table_name],
        [self.eventuality_columns, self.relation_columns],
        [self.eventuality_column_types, self.relation_column_types]
    )
    for table_name, columns, column_types in table_specs:
        if len(columns) == 0 or len(column_types) == 0:
            raise ValueError("Error: %s_columns and %s_column_types must be defined" % (table_name, table_name))
        try:
            self._conn.create_table(table_name, columns, column_types)
        except Exception:
            # Best effort (presumably the table already exists). Only ordinary
            # errors are ignored so KeyboardInterrupt/SystemExit still propagate.
            pass

    if self.mode == "memory":
        eventuality_rows = self._conn.get_columns(self.eventuality_table_name, self.eventuality_columns)
        for e in map(self._convert_row_to_eventuality, eventuality_rows):
            self.eids.add(e.eid)
            self.eid2eventuality_cache[e.eid] = e
            # index the eventuality under the space-joined value of every cached grain
            for k, v in self.partial2eids_cache.items():
                v.setdefault(" ".join(getattr(e, k)), []).append(e.eid)
        relation_rows = self._conn.get_columns(self.relation_table_name, self.relation_columns)
        for r in map(self._convert_row_to_relation, relation_rows):
            self.rids.add(r.rid)
            self.rid2relation_cache[r.rid] = r
            # index the relation under every cached attribute
            for k, v in self.partial2rids_cache.items():
                v.setdefault(getattr(r, k), []).append(r.rid)
    else:
        for e in self._conn.get_columns(self.eventuality_table_name, ["_id"]):
            self.eids.add(e["_id"])
        for r in self._conn.get_columns(self.relation_table_name, ["_id"]):
            self.rids.add(r["_id"])
def close(self):
    """ Close the ASERKGConnection safely

    Closes the database connection, then empties the id sets and every cache.
    """
    self._conn.close()
    for ids in (self.eids, self.rids):
        ids.clear()
    for id_cache in (self.eid2eventuality_cache, self.rid2relation_cache):
        id_cache.clear()
    # the partial caches keep their keys; only the per-key indices are emptied
    for partial_cache in (self.partial2eids_cache, self.partial2rids_cache):
        for index in partial_cache.values():
            index.clear()
"""
KG (Eventualities)
"""
def _convert_eventuality_to_row(self, eventuality):
row = OrderedDict({"_id": eventuality.eid})
for c in self.eventuality_columns[1:-1]:
d = getattr(eventuality, c)
if isinstance(d, list):
row[c] = " ".join(d)
else:
row[c] = d
row["info"] = eventuality.encode(minimum=True)
return row
def _convert_row_to_eventuality(self, row):
    """ Rebuild an eventuality from a database row

    The object is decoded from the "info" column; eid, frequency, and pattern are
    then overwritten with the values of the "_id", "frequency", and "pattern" columns.
    """
    decoded = Eventuality().decode(row["info"])
    for attribute, column in (("eid", "_id"), ("frequency", "frequency"), ("pattern", "pattern")):
        setattr(decoded, attribute, row[column])
    return decoded
def get_eventuality_columns(self, columns):
    """ Get column information from eventualities

    :param columns: the columns to retrieve
    :type columns: List[str]
    :return: a list of retrieved rows
    :rtype: List[Dict[str, object]]
    """
    table_name = self.eventuality_table_name
    retrieved = self._conn.get_columns(table_name, columns)
    return retrieved
def _insert_eventuality(self, eventuality):
row = self._convert_eventuality_to_row(eventuality)
self._conn.insert_row(self.eventuality_table_name, row)
if self.mode == "insert":
self.eids.add(eventuality.eid)
elif self.mode == "cache":
self.eids.add(eventuality.eid)
self.eid2eventuality_cache[eventuality.eid] = eventuality
for k, v in self.partial2eids_cache.items():
if eventuality.get(k) in v:
v[eventuality.get(k)].append(eventuality.eid)
elif self.mode == "memory":
self.eids.add(eventuality.eid)
self.eid2eventuality_cache[eventuality.eid] = eventuality
for k, v in self.partial2eids_cache.items():
if eventuality.get(k) not in v:
v[eventuality.get(k)] = [eventuality.eid]
else:
v[eventuality.get(k)].append(eventuality.eid)
return eventuality
def _insert_eventualities(self, eventualities):
rows = list(map(self._convert_eventuality_to_row, eventualities))
self._conn.insert_rows(self.eventuality_table_name, rows)
if self.mode == "insert":
for eventuality in eventualities:
self.eids.add(eventuality.eid)
elif self.mode == "cache":
for eventuality in eventualities:
self.eids.add(eventuality.eid)
self.eid2eventuality_cache[eventuality.eid] = eventuality
for k, v in self.partial2eids_cache.items():
if eventuality.get(k) in v:
v[eventuality.get(k)].append(eventuality.eid)
elif self.mode == "memory":
for eventuality in eventualities:
self.eids.add(eventuality.eid)
self.eid2eventuality_cache[eventuality.eid] = eventuality
for k, v in self.partial2eids_cache.items():
if eventuality.get(k) not in v:
v[eventuality.get(k)] = [eventuality.eid]
else:
v[eventuality.get(k)].append(eventuality.eid)
return eventualities
def _get_eventuality_and_store_in_cache(self, eid):
return self._get_eventualities_and_store_in_cache([eid])[0]
def _get_eventualities_and_store_in_cache(self, eids):
eventualities = list(
map(
self._convert_row_to_eventuality,
self._conn.select_rows(self.eventuality_table_name, eids, self.eventuality_columns)
)
)
for eventuality in eventualities:
if eventuality:
self.eid2eventuality_cache[eventuality.eid] = eventuality
# It seems not to need to append
# if self.mode == "cache":
# for k, v in self.partial2eids_cache.items():
# if eventuality.get(k) in v:
# v[eventuality.get(k)].append(eventuality.eid)
# elif self.mode == "memory":
# for k, v in self.partial2eids_cache.items():
# if eventuality.get(k) not in v:
# v[eventuality.get(k)] = [eventuality.eid]
# else:
# v[eventuality.get(k)].append(eventuality.eid)
return eventualities
def _update_eventuality(self, eventuality):
# update db
update_op = self._conn.get_update_op(["frequency"], "+")
row = self._convert_eventuality_to_row(eventuality)
self._conn.update_row(self.eventuality_table_name, row, update_op, ["frequency"])
# updata cache
if self.mode == "insert":
return None # don"t care
updated_eventuality = self.eid2eventuality_cache.get(eventuality.eid, None)
if updated_eventuality: # self.mode == "memory" or hit in cache
updated_eventuality.frequency += eventuality.frequency
else: # self.mode == "cache" and miss in cache
updated_eventuality = self._get_eventuality_and_store_in_cache(eventuality.eid)
return updated_eventuality
def _update_eventualities(self, eventualities):
# update db
update_op = self._conn.get_update_op(["frequency"], "+")
rows = list(map(self._convert_eventuality_to_row, eventualities))
self._conn.update_rows(self.eventuality_table_name, rows, update_op, ["frequency"])
# update cache
if self.mode == "insert":
return [None] * len(eventualities) # don"t care
updated_eventualities = []
missed_indices = []
missed_eids = []
for idx, eventuality in enumerate(eventualities):
if eventuality.eid not in self.eids:
updated_eventualities.append(None)
else:
updated_eventuality = self.eid2eventuality_cache.get(eventuality.eid, None)
updated_eventualities.append(updated_eventuality)
if updated_eventuality:
updated_eventuality.frequency += eventuality.frequency
else:
missed_indices.append(idx)
missed_eids.append(eventuality.eid)
for idx, updated_eventuality in enumerate(self._get_eventualities_and_store_in_cache(missed_eids)):
updated_eventualities[missed_indices[idx]] = updated_eventuality
return updated_eventualities
def insert_eventuality(self, eventuality):
    """ Insert/Update an eventuality into ASER
    (suggestion: consider to use `insert_eventualities` if you want to insert multiple eventualities)

    :param eventuality: an eventuality to insert/update
    :type eventuality: aser.eventuality.Eventuality
    :return: the inserted/updated eventuality
    :rtype: aser.eventuality.Eventuality
    """
    # a known eid means the row exists, so only its frequency is updated
    already_stored = eventuality.eid in self.eids
    handler = self._update_eventuality if already_stored else self._insert_eventuality
    return handler(eventuality)
def insert_eventualities(self, eventualities):
    """ Insert/Update eventualities into ASER

    :param eventualities: eventualities to insert/update
    :type eventualities: List[aser.eventuality.Eventuality]
    :return: the inserted/updated eventualities
    :rtype: List[aser.eventuality.Eventuality]
    """
    results = []
    fresh = []
    known = []
    known_positions = []
    # split the batch into new rows and rows whose eid is already known
    for position, item in enumerate(eventualities):
        if item.eid in self.eids:
            known_positions.append(position)
            known.append(item)
            results.append(None)  # placeholder, replaced after the batch update
        else:
            fresh.append(item)
            results.append(item)
    if fresh:
        self._insert_eventualities(fresh)
    if known:
        updated = self._update_eventualities(known)
        for slot, item in enumerate(updated):
            results[known_positions[slot]] = item
    return results
def get_exact_match_eventuality(self, eventuality):
    """ Retrieve an exact matched eventuality from ASER
    (suggestion: consider to use `get_exact_match_eventualities` if you want to retrieve multiple eventualities)

    :param eventuality: an eventuality that contains the eid
    :type eventuality: Union[aser.eventuality.Eventuality, Dict[str, object], str]
    :return: the exact matched eventuality
    :rtype: aser.eventuality.Eventuality
    :raises ValueError: if `eventuality` is none of the supported types
    """
    # resolve the eid from whichever representation was given
    if isinstance(eventuality, Eventuality):
        eid = eventuality.eid
    elif isinstance(eventuality, dict):
        eid = eventuality["eid"]
    elif isinstance(eventuality, str):
        eid = eventuality
    else:
        raise ValueError("Error: eventuality should be an instance of Eventuality, a dictionary, or a eid.")

    if eid not in self.eids:
        return None
    cached = self.eid2eventuality_cache.get(eid, None)
    # on a cache miss the eventuality is read from the database and cached
    return cached if cached else self._get_eventuality_and_store_in_cache(eid)
def get_exact_match_eventualities(self, eventualities):
    """ Retrieve multiple exact matched eventualities from ASER

    The representation (Eventuality, dictionary, or eid) is decided from the first
    element and applied to the whole list. The result has exactly one entry per
    input, with None for eids that are not stored.

    :param eventualities: eventualities
    :type eventualities: Union[List[aser.eventuality.Eventuality], List[Dict[str, object]], List[str]]
    :return: the exact matched eventualities
    :rtype: List[aser.eventuality.Eventuality]
    :raises ValueError: if the first element is none of the supported types
    """
    exact_match_eventualities = []
    if len(eventualities):
        first = eventualities[0]
        if isinstance(first, Eventuality):
            eids = [eventuality.eid for eventuality in eventualities]
        elif isinstance(first, dict):
            eids = [eventuality["eid"] for eventuality in eventualities]
        elif isinstance(first, str):
            eids = eventualities
        else:
            raise ValueError("Error: eventualities should instances of Eventuality, dictionaries, or eids.")

        missed_indices = []
        missed_eids = []
        for idx, eid in enumerate(eids):
            if eid not in self.eids:
                # Unknown eid: exactly one None placeholder. Without this early
                # continue a second entry was appended and the results no longer
                # lined up with the inputs.
                exact_match_eventualities.append(None)
                continue
            exact_match_eventuality = self.eid2eventuality_cache.get(eid, None)
            exact_match_eventualities.append(exact_match_eventuality)
            if not exact_match_eventuality:
                missed_indices.append(idx)
                missed_eids.append(eid)

        # fetch the cache misses in one batch; skip the database when nothing was missed
        if missed_eids:
            fetched = self._get_eventualities_and_store_in_cache(missed_eids)
            for idx, exact_match_eventuality in enumerate(fetched):
                exact_match_eventualities[missed_indices[idx]] = exact_match_eventuality
    return exact_match_eventualities
def get_eventualities_by_keys(self, bys, keys, order_bys=None, reverse=False, top_n=None):
    """ Retrieve multiple partial matched eventualities by keys and values from ASER

    Pairs whose column is not an eventuality column are ignored. The caller's
    `bys` and `keys` lists are not modified.

    :param bys: the given columns to match
    :type bys: List[str]
    :param keys: the given values to match
    :type keys: List[str]
    :param order_bys: the columns whose value are used to sort rows
    :type order_bys: List[str]
    :param reverse: whether to sort in a reversed order
    :type reverse: bool
    :param top_n: how many eventualities to return, default `None` for all eventualities
    :type top_n: int
    :return: the partial matched eventualities
    :rtype: List[aser.eventuality.Eventuality]
    """
    assert len(bys) == len(keys)
    # Filter into fresh lists instead of popping from the arguments, so the
    # caller's lists are left intact.
    valid_pairs = [(by, key) for by, key in zip(bys, keys) if by in self.eventuality_columns]
    if len(valid_pairs) == 0:
        return []
    bys = [by for by, _ in valid_pairs]
    keys = [key for _, key in valid_pairs]

    # pick the finest-grained partial cache that covers one of the requested columns
    cache = None
    by_index = -1
    for k in ["words", "skeleton_words", "verbs"]:
        if k in bys and k in self.partial2eids_cache:
            cache = self.partial2eids_cache[k]
            by_index = bys.index(k)
            break

    # NOTE(review): an empty cache dict is falsy, so this branch is skipped until
    # the selected cache holds at least one key - confirm whether `is not None`
    # was intended.
    if cache:
        if keys[by_index] in cache:
            key_match_eventualities = [self.eid2eventuality_cache[eid] for eid in cache[keys[by_index]]]
        else:
            if self.mode == "memory":
                # everything is loaded in memory, so a missing key means no match
                return []
            key_cache = []
            key_match_eventualities = list(
                map(
                    self._convert_row_to_eventuality,
                    self._conn.get_rows_by_keys(
                        self.eventuality_table_name, [bys[by_index]], [keys[by_index]], self.eventuality_columns
                    )
                )
            )
            for key_match_eventuality in key_match_eventualities:
                if key_match_eventuality.eid not in self.eid2eventuality_cache:
                    self.eid2eventuality_cache[key_match_eventuality.eid] = key_match_eventuality
                key_cache.append(key_match_eventuality.eid)
            cache[keys[by_index]] = key_cache
        # apply the remaining (column, value) constraints on the candidates
        for i in range(len(bys)):
            if i == by_index:
                continue
            key_match_eventualities = [x for x in key_match_eventualities if x[bys[i]] == keys[i]]
        if order_bys:
            key_match_eventualities.sort(key=operator.itemgetter(*order_bys), reverse=reverse)
        if top_n:
            key_match_eventualities = key_match_eventualities[:top_n]
        return key_match_eventualities

    # no usable cache: let the database match, sort, and truncate
    return list(
        map(
            self._convert_row_to_eventuality,
            self._conn.get_rows_by_keys(
                self.eventuality_table_name,
                bys,
                keys,
                self.eventuality_columns,
                order_bys=order_bys,
                reverse=reverse,
                top_n=top_n
            )
        )
    )
def get_partial_match_eventualities(self, eventuality, bys, top_n=None, threshold=0.8, sort=True):
    """ Retrieve multiple partial matched eventualities by a given eventuality and properties from ASER

    The properties in `bys` are tried in order and only the first one that yields
    candidates is used. With `sort=False` the candidates (or a random sample of
    `top_n` of them) are returned as they are; with `sort=True` the result is a
    list of (similarity, eventuality) pairs, best first.

    :param eventuality: the given eventuality to match
    :type eventuality: aser.eventuality.Eventuality
    :param bys: the given properties to match
    :type bys: List[str]
    :param top_n: how many rows to return, default `None` for all rows
    :type top_n: int
    :param threshold: the minimum similarity
    :type threshold: float (default = 0.8)
    :param sort: whether to sort
    :type sort: bool (default = True)
    :return: the partial matched eventualities
    :rtype: List[aser.eventuality.Eventuality]
    """
    assert self.grain is not None
    for by in bys:
        # exact match on the space-joined property value
        candidates = self.get_eventualities_by_keys([by], [" ".join(getattr(eventuality, by))])
        if len(candidates) == 0:
            continue

        if not sort:
            if top_n and len(candidates) > top_n:
                return random.sample(candidates, top_n)
            return candidates

        # min-heap ordered by (similarity, frequency, idx); idx breaks ties so the
        # eventualities themselves are never compared
        heap = []
        for idx, candidate in enumerate(candidates):
            similarity = compute_overlap(getattr(eventuality, self.grain), getattr(candidate, self.grain))
            if similarity < threshold:
                continue
            entry = (similarity, candidate.frequency, idx, candidate)
            if not top_n or len(heap) < top_n:
                heapq.heappush(heap, entry)
            else:
                # heap is full: keep only the best top_n entries
                heapq.heappushpop(heap, entry)

        # popping yields ascending order, so reverse for best-first
        ascending = [heapq.heappop(heap) for _ in range(len(heap))]
        return [(entry[0], entry[-1]) for entry in reversed(ascending)]
    return []
"""
KG (Relations)
"""
def _convert_relation_to_row(self, relation):
    """ Convert a relation into an ordered database row

    The row starts with "_id" (the rid), copies the non-sense columns from the
    relation's attributes of the same name, and ends with one value per relation
    sense (0.0 when the relation does not carry that sense).
    """
    row = OrderedDict()
    row["_id"] = relation.rid
    sense_count = len(relation_senses)
    for column in self.relation_columns[1:-sense_count]:
        row[column] = getattr(relation, column)
    for sense in relation_senses:
        row[sense] = relation.relations.get(sense, 0.0)
    return row
def _convert_row_to_relation(self, row):
    """ Rebuild a relation from a database row

    Every column holding a float greater than 0.0 is passed on as a sense weight.
    """
    hid, tid = row["hid"], row["tid"]
    positive_weights = {}
    for column, value in row.items():
        if isinstance(value, float) and value > 0.0:
            positive_weights[column] = value
    return Relation(hid, tid, positive_weights)
def get_relation_columns(self, columns):
    """ Get column information from relations

    :param columns: the columns to retrieve
    :type columns: List[str]
    :return: a list of retrieved rows
    :rtype: List[Dict[str, object]]
    """
    table_name = self.relation_table_name
    retrieved = self._conn.get_columns(table_name, columns)
    return retrieved
def _insert_relation(self, relation):
row = self._convert_relation_to_row(relation)
self._conn.insert_row(self.relation_table_name, row)
if self.mode == "insert":
self.rids.add(relation.rid)
elif self.mode == "cache":
self.rids.add(relation.rid)
self.rid2relation_cache[relation.rid] = relation
for k, v in self.partial2rids_cache.items():
if relation.get(k) in v:
v[relation.get(k)].append(relation.rid)
elif self.mode == "memory":
self.rids.add(relation.rid)
self.rid2relation_cache[relation.rid] = relation
for k, v in self.partial2rids_cache.items():
if relation.get(k) not in v:
v[relation.get(k)] = [relation.rid]
else:
v[relation.get(k)].append(relation.rid)
return relation
def _insert_relations(self, relations):
rows = list(map(self._convert_relation_to_row, relations))
self._conn.insert_rows(self.relation_table_name, rows)
if self.mode == "insert":
for relation in relations:
self.rids.add(relation.rid)
elif self.mode == "cache":
for relation in relations:
self.rids.add(relation.rid)
self.rid2relation_cache[relation.rid] = relation
for k, v in self.partial2rids_cache.items():
if relation.get(k) in v:
v[relation.get(k)].append(relation.rid)
elif self.mode == "memory":
for relation in relations:
self.rids.add(relation.rid)
self.rid2relation_cache[relation.rid] = relation
for k, v in self.partial2rids_cache.items():
if relation.get(k) not in v:
v[relation.get(k)] = [relation.rid]
else:
v[relation.get(k)].append(relation.rid)
return relations
def _get_relation_and_store_in_cache(self, rid):
return self._get_relations_and_store_in_cache([rid])[0]
def _get_relations_and_store_in_cache(self, rids):
relations = list(
map(
self._convert_row_to_relation,
self._conn.select_rows(self.relation_table_name, rids, self.relation_columns)
)
)
for relation in relations:
if relation:
self.rid2relation_cache[relation.rid] = relation
return relations
def _update_relation(self, relation):
    """ Add the sense weights of an already stored relation in the database and in the cache

    Only the senses that the incoming relation carries with a positive weight are
    updated.

    :param relation: the relation whose weights are added to the stored one
    :type relation: aser.relation.Relation
    :return: the cached relation after the update, fetched from the database
        (and cached) on a cache miss
    :rtype: aser.relation.Relation
    """
    # find new relation frequencies
    update_columns = [r for r in relation_senses if relation.relations.get(r, 0.0) > 0.0]

    # update db
    update_op = self._conn.get_update_op(update_columns, "+")
    row = self._convert_relation_to_row(relation)
    self._conn.update_row(self.relation_table_name, row, update_op, update_columns)

    # update cache
    updated_relation = self.rid2relation_cache.get(relation.rid, None)
    if updated_relation:
        for r in update_columns:
            # The weights live in `relations` (as everywhere else in this class),
            # and the cached relation may not carry this sense yet, so start from 0.0.
            updated_relation.relations[r] = updated_relation.relations.get(r, 0.0) + relation.relations[r]
    else:
        updated_relation = self._get_relation_and_store_in_cache(relation.rid)
    return updated_relation
def _update_relations(self, relations):
    """ Add the sense weights of already stored relations in the database and in the cache

    :param relations: the relations whose weights are added to the stored ones
    :type relations: List[aser.relation.Relation]
    :return: one entry per input: None for unknown rids, otherwise the cached
        relation after the update (fetched from the database on a cache miss)
    :rtype: List[aser.relation.Relation]
    """
    # update db
    update_op = self._conn.get_update_op(relation_senses, "+")
    rows = list(map(self._convert_relation_to_row, relations))
    self._conn.update_rows(self.relation_table_name, rows, update_op, relation_senses)

    # update cache
    updated_relations = []
    missed_indices = []
    missed_rids = []
    for idx, relation in enumerate(relations):
        if relation.rid not in self.rids:
            updated_relations.append(None)
            continue
        updated_relation = self.rid2relation_cache.get(relation.rid, None)
        # append the relation itself (the result list used to be appended to itself)
        updated_relations.append(updated_relation)
        if updated_relation:
            for r in relation_senses:
                # Merge every sense the incoming relation carries, mirroring the
                # database update; the cached relation may not have the sense yet.
                increment = relation.relations.get(r, 0.0)
                if increment > 0.0:
                    updated_relation.relations[r] = updated_relation.relations.get(r, 0.0) + increment
        else:
            missed_indices.append(idx)
            missed_rids.append(relation.rid)

    # fill the cache misses from the database in one batch
    if missed_rids:
        for idx, updated_relation in enumerate(self._get_relations_and_store_in_cache(missed_rids)):
            updated_relations[missed_indices[idx]] = updated_relation
    return updated_relations
def insert_relation(self, relation):
    """ Insert/Update a relation into ASER
    (suggestion: consider to use `insert_relations` if you want to insert multiple relations)

    :param relation: a relation to insert/update
    :type relation: aser.relation.Relation
    :return: the inserted/updated relation
    :rtype: aser.relation.Relation
    """
    # Existence is decided by the set of known rids, as in `insert_relations`.
    # The object cache does not hold every stored relation, so testing it would
    # re-insert rows that already exist.
    if relation.rid not in self.rids:
        return self._insert_relation(relation)
    return self._update_relation(relation)
def insert_relations(self, relations):
    """ Insert/Update relations into ASER

    :param relations: relations to insert/update
    :type relations: List[aser.relation.Relation]
    :return: the inserted/updated relations
    :rtype: List[aser.relation.Relation]
    """
    results = []
    fresh = []
    known = []
    known_positions = []
    # split the batch into new rows and rows whose rid is already known
    for position, item in enumerate(relations):
        if item.rid in self.rids:
            known_positions.append(position)
            known.append(item)
            results.append(None)  # placeholder, replaced after the batch update
        else:
            fresh.append(item)
            results.append(item)
    if fresh:
        self._insert_relations(fresh)
    if known:
        updated = self._update_relations(known)
        for slot, item in enumerate(updated):
            results[known_positions[slot]] = item
    return results
def get_exact_match_relation(self, relation):
    """ Retrieve an exact matched relation from ASER
    (suggestion: consider to use `get_exact_match_relations` if you want to retrieve multiple relations)

    :param relation: a relation that contains the rid or an eventuality pair that contains two eids
    :type relation: Union[aser.relation.Relation, Dict[str, object], str, Tuple[aser.eventuality.Eventuality, aser.eventuality.Eventuality], Tuple[str, str]]
    :return: the exact matched relation
    :rtype: aser.relation.Relation
    :raises ValueError: if `relation` is none of the supported representations
    """
    # resolve the rid from whichever representation was given
    if isinstance(relation, Relation):
        rid = relation.rid
    elif isinstance(relation, dict):
        rid = relation["rid"]
    elif isinstance(relation, str):
        rid = relation
    elif isinstance(relation, (tuple, list)) and len(relation) == 2:
        head, tail = relation[0], relation[1]
        if isinstance(head, Eventuality) and isinstance(tail, Eventuality):
            rid = Relation.generate_rid(head.eid, tail.eid)
        elif isinstance(head, str) and isinstance(tail, str):
            rid = Relation.generate_rid(head, tail)
        else:
            raise ValueError(
                "Error: relation should be (an instance of Eventuality, an instance of Eventuality) or (hid, tid)."
            )
    else:
        raise ValueError(
            "Error: relation should be an instance of Relation, a dictionary, rid,"
            "(an instance of Eventuality, an instance of Eventuality), or (hid, tid)."
        )

    if rid not in self.rids:
        return None
    cached = self.rid2relation_cache.get(rid, None)
    # on a cache miss the relation is read from the database and cached
    return cached if cached else self._get_relation_and_store_in_cache(rid)
def get_exact_match_relations(self, relations):
    """ Retrieve exact matched relations from ASER

    The representation is decided from the first element and applied to the whole
    list. The result has exactly one entry per input, with None for rids that are
    not stored.

    :param relations: a relations that contain the rids or eventuality pairs each of which contains two eids
    :type relations: Union[List[aser.relation.Relation], List[Dict[str, object]], List[str], List[Tuple[aser.eventuality.Eventuality, aser.eventuality.Eventuality]], List[Tuple[str, str]]]
    :return: the exact matched relations
    :rtype: List[aser.relation.Relation]
    :raises ValueError: if the first element is none of the supported representations
    """
    exact_match_relations = []
    if len(relations):
        first = relations[0]
        if isinstance(first, Relation):
            rids = [relation.rid for relation in relations]
        elif isinstance(first, dict):
            rids = [relation["rid"] for relation in relations]
        elif isinstance(first, str):
            rids = relations
        elif isinstance(first, (tuple, list)) and len(first) == 2:
            if isinstance(first[0], Eventuality) and isinstance(first[1], Eventuality):
                rids = [Relation.generate_rid(relation[0].eid, relation[1].eid) for relation in relations]
            elif isinstance(first[0], str) and isinstance(first[1], str):
                rids = [Relation.generate_rid(relation[0], relation[1]) for relation in relations]
            else:
                raise ValueError(
                    "Error: relations should be [(an instance of Eventuality, an instance of Eventuality), ...] or [(hid, tid), ...]."
                )
        else:
            raise ValueError(
                "Error: relations should be instances of Relation, dictionaries, rids, [(an instance of Eventuality, an instance of Eventuality), ...], or [(hid, tid), ...]."
            )

        missed_indices = []
        missed_rids = []
        for idx, rid in enumerate(rids):
            if rid not in self.rids:
                # Unknown rid: exactly one None placeholder. Without this early
                # continue a second entry was appended and the results no longer
                # lined up with the inputs.
                exact_match_relations.append(None)
                continue
            exact_match_relation = self.rid2relation_cache.get(rid, None)
            exact_match_relations.append(exact_match_relation)
            if not exact_match_relation:
                missed_indices.append(idx)
                missed_rids.append(rid)

        # fetch the cache misses in one batch; skip the database when nothing was missed
        if missed_rids:
            fetched = self._get_relations_and_store_in_cache(missed_rids)
            for idx, exact_match_relation in enumerate(fetched):
                exact_match_relations[missed_indices[idx]] = exact_match_relation
    return exact_match_relations
def get_relations_by_keys(self, bys, keys, order_bys=None, reverse=False, top_n=None):
    """ Retrieve multiple partial matched relations by keys and values from ASER

    Pairs whose column is not a relation column are ignored. The caller's `bys`
    and `keys` lists are not modified.

    :param bys: the given columns to match
    :type bys: List[str]
    :param keys: the given values to match
    :type keys: List[str]
    :param order_bys: the columns whose value are used to sort rows
    :type order_bys: Union[List[str], None] (default = None)
    :param reverse: whether to sort in a reversed order
    :type reverse: bool (default = False)
    :param top_n: how many relations to return, default `None` for all relations
    :type top_n: Union[int, None] (default = None)
    :return: the partial matched relations
    :rtype: List[aser.relation.Relation]
    """
    assert len(bys) == len(keys)
    # Filter into fresh lists instead of popping from the arguments, so the
    # caller's lists are left intact.
    valid_pairs = [(by, key) for by, key in zip(bys, keys) if by in self.relation_columns]
    if len(valid_pairs) == 0:
        return []
    bys = [by for by, _ in valid_pairs]
    keys = [key for _, key in valid_pairs]

    # pick a partial cache that covers one of the requested columns
    cache = None
    by_index = -1
    for k in ["hid", "tid"]:
        if k in bys and k in self.partial2rids_cache:
            cache = self.partial2rids_cache[k]
            by_index = bys.index(k)
            break

    # NOTE(review): an empty cache dict is falsy, so this branch is skipped until
    # the selected cache holds at least one key - confirm whether `is not None`
    # was intended.
    if cache:
        if keys[by_index] in cache:
            key_match_relations = [self.rid2relation_cache[rid] for rid in cache[keys[by_index]]]
        else:
            if self.mode == "memory":
                # everything is loaded in memory, so a missing key means no match
                return []
            key_cache = []
            key_match_relations = list(
                map(
                    self._convert_row_to_relation,
                    self._conn.get_rows_by_keys(
                        self.relation_table_name, [bys[by_index]], [keys[by_index]], self.relation_columns
                    )
                )
            )
            for key_match_relation in key_match_relations:
                if key_match_relation.rid not in self.rid2relation_cache:
                    self.rid2relation_cache[key_match_relation.rid] = key_match_relation
                key_cache.append(key_match_relation.rid)
            cache[keys[by_index]] = key_cache
        # apply the remaining (column, value) constraints on the candidates
        for i in range(len(bys)):
            if i == by_index:
                continue
            key_match_relations = [x for x in key_match_relations if x[bys[i]] == keys[i]]
        if order_bys:
            key_match_relations.sort(key=operator.itemgetter(*order_bys), reverse=reverse)
        if top_n:
            key_match_relations = key_match_relations[:top_n]
        return key_match_relations

    # no usable cache: let the database match, sort, and truncate
    return list(
        map(
            self._convert_row_to_relation,
            self._conn.get_rows_by_keys(
                self.relation_table_name,
                bys,
                keys,
                self.relation_columns,
                order_bys=order_bys,
                reverse=reverse,
                top_n=top_n
            )
        )
    )
"""
Additional APIs
"""
[docs]class ASERConceptConnection(object):
""" Concept connection for ASER (including concepts, concept_instance_pairs, and relations)
"""
def __init__(self, db_path, db="sqlite", mode='cache', chunksize=CHUNKSIZE):
    """ Open the backend connection, validate the options, and prepare id sets and caches

    :param db_path: database path
    :type db_path: str
    :param db: the backend database, e.g., "sqlite" or "mongodb"
    :type db: str (default = sqlite)
    :param mode: the mode to use the connection.
        "insert": this connection is only used to insert/update rows;
        "cache": this connection caches some contents that have been retrieved;
        "memory": this connection loads all contents in memory;
    :type mode: str (default = "cache")
    :param chunksize: the chunksize to load/write database
    :type chunksize: int (default = 32768)
    :raises NotImplementedError: if `db` or `mode` is not a supported value
    """
    # The backend is opened before the mode is validated.
    if db == "sqlite":
        self._conn = SqliteDBConnection(db_path, chunksize)
    elif db == "mongodb":
        self._conn = MongoDBConnection(db_path, chunksize)
    else:
        raise NotImplementedError("Error: %s database is not supported!" % (db))

    self.mode = mode
    if self.mode not in ["insert", "cache", "memory"]:
        raise NotImplementedError("Error: only support insert/cache/memory modes.")

    # Table layouts come from the module-level constants.
    self.concept_table_name = CONCEPT_TABLE_NAME
    self.concept_columns = CONCEPT_COLUMNS
    self.concept_column_types = CONCEPT_COLUMN_TYPES
    self.concept_instance_pair_table_name = CONCEPTINSTANCEPAIR_TABLE_NAME
    self.concept_instance_pair_columns = CONCEPTINSTANCEPAIR_COLUMNS
    self.concept_instance_pair_column_types = CONCEPTINSTANCEPAIR_COLUMN_TYPES
    self.relation_table_name = RELATION_TABLE_NAME
    self.relation_columns = RELATION_COLUMNS
    self.relation_column_types = RELATION_COLUMN_TYPES

    # Known ids.
    self.cids, self.eids, self.rids = set(), set(), set()

    # Object caches and lookup tables.
    self.cid2concept_cache = dict()
    self.cid2eid_pattern_scores = dict()
    self.rid2relation_cache = dict()
    self.eid2cid_scores = dict()
    self.partial2cids_cache = dict()
    self.partial2rids_cache = {"hid": dict()}

    self.init()
def init(self):
    """ Initialize the ASERConceptConnection, including creating tables, loading cids, eids, rids, and building cache

    In "memory" mode every concept, concept-instance pair, and relation row is
    converted and cached; in the other modes only the ids are loaded.

    :raises NotImplementedError: if a table has no columns or no column types defined
    """
    table_specs = zip(
        [self.concept_table_name, self.concept_instance_pair_table_name, self.relation_table_name],
        [self.concept_columns, self.concept_instance_pair_columns, self.relation_columns],
        [self.concept_column_types, self.concept_instance_pair_column_types, self.relation_column_types]
    )
    for table_name, columns, column_types in table_specs:
        if len(columns) == 0 or len(column_types) == 0:
            raise NotImplementedError(
                "Error: %s_columns and %s_column_types must be defined" % (table_name, table_name)
            )
        try:
            self._conn.create_table(table_name, columns, column_types)
        except Exception:
            # Best effort (presumably the table already exists). Only ordinary
            # errors are ignored so KeyboardInterrupt/SystemExit still propagate.
            pass

    if self.mode == 'memory':
        concept_rows = self._conn.get_columns(self.concept_table_name, self.concept_columns)
        for c in map(self._convert_row_to_concept, concept_rows):
            self.cids.add(c.cid)
            self.cid2concept_cache[c.cid] = c
            # index the concept under every cached attribute
            for k, v in self.partial2cids_cache.items():
                v.setdefault(getattr(c, k), []).append(c.cid)

        pair_rows = self._conn.get_columns(
            self.concept_instance_pair_table_name, self.concept_instance_pair_columns
        )
        for p in map(self._convert_row_to_concept_instance_pair, pair_rows):
            self.eids.add(p.eid)
            # keep both directions of the concept <-> instance mapping
            self.cid2eid_pattern_scores.setdefault(p.cid, []).append((p.eid, p.pattern, p.score))
            self.eid2cid_scores.setdefault(p.eid, []).append((p.cid, p.score))

        relation_rows = self._conn.get_columns(self.relation_table_name, self.relation_columns)
        for r in map(self._convert_row_to_relation, relation_rows):
            self.rids.add(r.rid)
            self.rid2relation_cache[r.rid] = r
            # index the relation under every cached attribute
            for k, v in self.partial2rids_cache.items():
                v.setdefault(getattr(r, k), []).append(r.rid)
    else:
        for x in self._conn.get_columns(self.concept_table_name, ["_id"]):
            self.cids.add(x["_id"])
        for x in self._conn.get_columns(self.concept_instance_pair_table_name, ["eid"]):
            self.eids.add(x["eid"])
        for x in self._conn.get_columns(self.relation_table_name, ["_id"]):
            self.rids.add(x["_id"])
def close(self):
    """ Close the ASERConceptConnection safely

    Closes the database connection, then empties the id sets and every cache.
    """
    self._conn.close()
    for ids in (self.cids, self.eids, self.rids):
        ids.clear()
    for lookup in (
        self.cid2concept_cache, self.cid2eid_pattern_scores, self.eid2cid_scores, self.rid2relation_cache
    ):
        lookup.clear()
    # the partial caches keep their keys; only the per-key indices are emptied
    for partial_cache in (self.partial2cids_cache, self.partial2rids_cache):
        for index in partial_cache.values():
            index.clear()
"""
KG (Concepts)
"""
def _convert_concept_to_row(self, concept):
    """ Serialize a concept into an ordered database row

    :param concept: the concept to serialize
    :type concept: aser.concept.ASERConcept
    :return: a row whose first key is "_id" and whose last key is "info"
    :rtype: OrderedDict
    """

    row = OrderedDict()
    row["_id"] = concept.cid
    # the columns between "_id" and "info" are copied from same-named attributes
    for column in self.concept_columns[1:-1]:
        value = getattr(concept, column)
        # list values are flattened into one space-separated string
        row[column] = " ".join(value) if isinstance(value, list) else value
    row["info"] = concept.encode()
    return row
def _convert_row_to_concept(self, row):
    """ Rebuild a concept from a database row

    :param row: a row that provides the "info" and "_id" fields
    :type row: Dict[str, object]
    :return: the concept decoded from "info", with its cid set from "_id"
    :rtype: aser.concept.ASERConcept
    """

    decoded = ASERConcept().decode(row["info"])
    setattr(decoded, "cid", row["_id"])
    return decoded
def get_concept_columns(self, columns):
    """ Get column information from concepts

    :param columns: the columns to retrieve
    :type columns: List[str]
    :return: a list of retrieved rows
    :rtype: List[Dict[str, object]]
    """

    table = self.concept_table_name
    return self._conn.get_columns(table, columns)
def _insert_concept(self, concept):
    """ Insert a new concept into the database and register it in the caches

    The bookkeeping mirrors `_insert_concepts`:
    "insert" only records the cid;
    "cache" records the cid, caches the concept, and extends partial-cache
    lists that already exist;
    "memory" records the cid, caches the concept, and creates or extends the
    partial-cache lists.

    :param concept: the concept to insert
    :type concept: aser.concept.ASERConcept
    :return: the inserted concept (the same object)
    :rtype: aser.concept.ASERConcept
    """

    row = self._convert_concept_to_row(concept)
    self._conn.insert_row(self.concept_table_name, row)
    if self.mode == "insert":
        self.cids.add(concept.cid)
    elif self.mode == "cache":
        self.cids.add(concept.cid)
        self.cid2concept_cache[concept.cid] = concept
        for k, v in self.partial2cids_cache.items():
            # only extend lists that are already cached; presumably a missing
            # key has not been loaded yet, so a fresh one-element list would
            # look complete while being partial -- TODO confirm
            if concept.get(k) in v:
                v[concept.get(k)].append(concept.cid)
    elif self.mode == "memory":
        self.cids.add(concept.cid)
        self.cid2concept_cache[concept.cid] = concept
        for k, v in self.partial2cids_cache.items():
            if concept.get(k) not in v:
                v[concept.get(k)] = [concept.cid]
            else:
                v[concept.get(k)].append(concept.cid)
    return concept
def _insert_concepts(self, concepts):
    """ Insert several new concepts and register them in the caches

    "insert" mode only records the cids. "cache" mode also caches each concept
    and extends partial-cache lists that already exist. "memory" mode caches
    each concept and creates or extends the partial-cache lists. Any other
    mode leaves the in-memory state untouched.

    :param concepts: the concepts to insert
    :type concepts: List[aser.concept.ASERConcept]
    :return: the given concepts
    :rtype: List[aser.concept.ASERConcept]
    """

    rows = [self._convert_concept_to_row(item) for item in concepts]
    self._conn.insert_rows(self.concept_table_name, rows)

    mode = self.mode
    if mode not in ("insert", "cache", "memory"):
        return concepts
    for item in concepts:
        self.cids.add(item.cid)
        if mode == "insert":
            continue
        self.cid2concept_cache[item.cid] = item
        for attr, index in self.partial2cids_cache.items():
            key = item.get(attr)
            if key in index:
                index[key].append(item.cid)
            elif mode == "memory":
                # only the fully loaded mode may start a new list
                index[key] = [item.cid]
    return concepts
def _get_concept_and_store_in_cache(self, cid):
    """ Load one concept through the batch loader and return it

    :param cid: the id of the concept to load
    :type cid: str
    :return: the first element returned by `_get_concepts_and_store_in_cache`
    :rtype: aser.concept.ASERConcept
    """

    loaded = self._get_concepts_and_store_in_cache([cid])
    return loaded[0]
def _get_concepts_and_store_in_cache(self, cids):
    """ Load concepts from the database and remember them in the caches

    Each truthy concept is stored in `cid2concept_cache`. Its instances are
    taken from `cid2eid_pattern_scores` when a non-empty entry exists there;
    otherwise they are queried from the concept-instance-pair table and the
    cache entry is filled.

    :param cids: the ids of the concepts to load
    :type cids: List[str]
    :return: the converted concepts, in the order returned by the database
    :rtype: List[aser.concept.ASERConcept]
    """

    fetched_rows = self._conn.select_rows(self.concept_table_name, cids, self.concept_columns)
    concepts = [self._convert_row_to_concept(row) for row in fetched_rows]
    for concept in concepts:
        if not concept:
            continue
        cid = concept.cid
        self.cid2concept_cache[cid] = concept
        instances = self.cid2eid_pattern_scores.get(cid, None)
        if not instances:
            pair_rows = self._conn.get_rows_by_keys(
                self.concept_instance_pair_table_name,
                bys=["cid"],
                keys=[cid],
                columns=["eid", "pattern", "score"]
            )
            instances = [(pair["eid"], pair["pattern"], pair["score"]) for pair in pair_rows]
            self.cid2eid_pattern_scores[cid] = instances
        # the concept shares the very list object kept in the cache
        concept.instances = instances
    return concepts
def _update_concept(self, concept):
    """ Merge the instances of `concept` into the stored concept and save it

    An instance whose first field (the eid) already exists on the stored
    concept has its third field (the score) increased; other instances are
    appended. Only the "info" column is rewritten in the database.

    :param concept: a concept that already exists, carrying new instances
    :type concept: aser.concept.ASERConcept
    :return: the updated concept, or None in "insert" mode
    :rtype: Union[aser.concept.ASERConcept, None]
    """

    updated_concept = self.cid2concept_cache.get(concept.cid, None)
    if not updated_concept:  # cache miss: load the stored concept first
        if self.mode == "insert":
            updated_concept = self._convert_row_to_concept(
                self._conn.select_row(self.concept_table_name, concept.cid, self.concept_columns)
            )
        else:
            updated_concept = self._get_concept_and_store_in_cache(concept.cid)

    instances = updated_concept.instances
    for x in concept.instances:
        for pos, y in enumerate(instances):
            if y[0] == x[0]:
                if isinstance(y, tuple):
                    # instances loaded through the cache are tuples, which do
                    # not support item assignment: replace the whole entry
                    instances[pos] = y[:2] + (y[2] + x[2], ) + y[3:]
                else:
                    y[2] += x[2]
                break
        else:
            instances.append(x)

    update_op = self._conn.get_update_op(["info"], "=")
    row = self._convert_concept_to_row(updated_concept)
    self._conn.update_row(self.concept_table_name, row, update_op, ["info"])
    if self.mode == "insert":
        return None  # the caller does not use the result in this mode
    return updated_concept
def _update_concepts(self, concepts):
    """ Merge the instances of several concepts into their stored versions

    Stored concepts are taken from `cid2concept_cache` or, on a miss, loaded
    from the database. A concept whose cid is unknown, or whose stored version
    cannot be loaded, is used as-is. For the others, an instance whose eid
    already exists has its score increased and new instances are appended.
    The "info" column of every resulting concept is rewritten.

    :param concepts: concepts carrying new instances
    :type concepts: List[aser.concept.ASERConcept]
    :return: the updated concepts, or a list of None in "insert" mode
    :rtype: List[Union[aser.concept.ASERConcept, None]]
    """

    updated_concepts = []
    missed_indices = []
    missed_cids = []
    for idx, concept in enumerate(concepts):
        if concept.cid not in self.cids:
            updated_concepts.append(None)
        else:
            updated_concept = self.cid2concept_cache.get(concept.cid, None)
            updated_concepts.append(updated_concept)
            if not updated_concept:
                missed_indices.append(idx)
                missed_cids.append(concept.cid)

    if missed_cids:  # skip the database round trip when everything was cached
        if self.mode == "insert":
            fetched = map(
                self._convert_row_to_concept,
                self._conn.select_rows(self.concept_table_name, missed_cids, self.concept_columns)
            )
        else:
            fetched = self._get_concepts_and_store_in_cache(missed_cids)
        for idx, updated_concept in enumerate(fetched):
            updated_concepts[missed_indices[idx]] = updated_concept

    for idx, concept in enumerate(concepts):
        updated_concept = updated_concepts[idx]
        if not updated_concept:
            updated_concepts[idx] = concept
            continue
        instances = updated_concept.instances
        for x in concept.instances:
            for pos, y in enumerate(instances):
                if y[0] == x[0]:
                    if isinstance(y, tuple):
                        # cached instances are tuples, which do not support
                        # item assignment: replace the whole entry
                        instances[pos] = y[:2] + (y[2] + x[2], ) + y[3:]
                    else:
                        y[2] += x[2]
                    break
            else:
                instances.append(x)

    update_op = self._conn.get_update_op(["info"], "=")
    rows = list(map(self._convert_concept_to_row, updated_concepts))
    self._conn.update_rows(self.concept_table_name, rows, update_op, ["info"])
    if self.mode == "insert":
        return [None] * len(concepts)  # the caller does not use the result in this mode
    return updated_concepts
def insert_concept(self, concept):
    """ Insert/Update a concept into ASER
    (suggestion: consider to use `insert_concepts` if you want to insert multiple concepts)

    :param concept: a concept to insert/update
    :type concept: aser.concept.ASERConcept
    :return: the inserted/updated concept
    :rtype: aser.concept.ASERConcept
    """

    # a known cid means the stored concept is updated, otherwise a new row is created
    handler = self._update_concept if concept.cid in self.cids else self._insert_concept
    return handler(concept)
def insert_concepts(self, concepts):
    """ Insert/Update concepts into ASER

    :param concepts: concepts to insert/update
    :type concepts: List[aser.concept.ASERConcept]
    :return: the inserted/updated concepts
    :rtype: List[aser.concept.ASERConcept]
    """

    results = []
    fresh = []
    known = []
    known_positions = []
    for position, concept in enumerate(concepts):
        if concept.cid in self.cids:
            # placeholder, replaced below by the updated concept
            known_positions.append(position)
            known.append(concept)
            results.append(None)
        else:
            fresh.append(concept)
            results.append(concept)
    if fresh:
        self._insert_concepts(fresh)
    if known:
        for position, updated in zip(known_positions, self._update_concepts(known)):
            results[position] = updated
    return results
def get_exact_match_concept(self, concept):
    """ Retrieve a exact matched concept from ASER
    (suggestion: consider to use `get_exact_match_concepts` if you want to retrieve multiple concepts)

    :param concept: a concept that contains the cid
    :type concept: Union[aser.concept.ASERConcept, Dict[str, object], str]
    :return: the exact matched concept, or None if the cid is unknown
    :rtype: aser.concept.ASERConcept
    :raises ValueError: if `concept` is none of the supported types
    """

    # resolve the cid from whichever representation was given
    if isinstance(concept, ASERConcept):
        cid = concept.cid
    elif isinstance(concept, dict):
        cid = concept["cid"]
    elif isinstance(concept, str):
        cid = concept
    else:
        raise ValueError("Error: conceptualize should be an instance of ASERConcept, a dictionary, or a cid.")

    if cid not in self.cids:
        return None
    cached = self.cid2concept_cache.get(cid, None)
    return cached if cached else self._get_concept_and_store_in_cache(cid)
def get_exact_match_concepts(self, concepts):
    """ Retrieve multiple exact matched concepts from ASER

    The type of the first element decides how all elements are read. The
    result has one entry per input: None for an unknown cid, otherwise the
    concept from the cache or, on a miss, loaded from the database.

    :param concepts: concepts
    :type concepts: Union[List[aser.concept.ASERConcept], List[Dict[str, object]], List[str]]
    :return: the exact matched concepts
    :rtype: List[aser.concept.ASERConcept]
    :raises ValueError: if the first element is none of the supported types
    """

    exact_match_concepts = []
    if len(concepts):
        if isinstance(concepts[0], ASERConcept):
            cids = [concept.cid for concept in concepts]
        elif isinstance(concepts[0], dict):
            cids = [concept["cid"] for concept in concepts]
        elif isinstance(concepts[0], str):
            cids = concepts
        else:
            raise ValueError("Error: concepts should instances of ASERConcept, dictionaries, or cids.")

        missed_indices = []
        missed_cids = []
        for idx, cid in enumerate(cids):
            if cid not in self.cids:
                # unknown cid: exactly one None so that positions stay aligned
                exact_match_concepts.append(None)
                continue
            exact_match_concept = self.cid2concept_cache.get(cid, None)
            exact_match_concepts.append(exact_match_concept)
            if not exact_match_concept:
                missed_indices.append(idx)
                missed_cids.append(cid)
        if missed_cids:
            for idx, exact_match_concept in enumerate(self._get_concepts_and_store_in_cache(missed_cids)):
                exact_match_concepts[missed_indices[idx]] = exact_match_concept
    return exact_match_concepts
def get_concepts_by_keys(self, bys, keys, order_bys=None, reverse=False, top_n=None):
    """ Retrieve multiple partial matched concepts by keys and values from ASER

    Pairs whose column is not one of the concept columns are ignored. The
    given `bys` and `keys` lists are not modified.

    :param bys: the given columns to match
    :type bys: List[str]
    :param keys: the given values to match
    :type keys: List[str]
    :param order_bys: the columns whose value are used to sort rows
    :type order_bys: Union[List[str], None] (default = None)
    :param reverse: whether to sort in a reversed order
    :type reverse: bool (default = False)
    :param top_n: how many concepts to return, default `None` for all concepts
    :type top_n: Union[int, None] (default = None)
    :return: the partial matched concepts
    :rtype: List[aser.concept.ASERConcept]
    """

    assert len(bys) == len(keys)
    # filter into fresh lists instead of popping from the caller's arguments
    valid_pairs = [(by, key) for by, key in zip(bys, keys) if by in self.concept_columns]
    if not valid_pairs:
        return []
    valid_bys = [by for by, _ in valid_pairs]
    valid_keys = [key for _, key in valid_pairs]
    return list(
        map(
            self._convert_row_to_concept,
            self._conn.get_rows_by_keys(
                self.concept_table_name,
                valid_bys,
                valid_keys,
                self.concept_columns,
                order_bys=order_bys,
                reverse=reverse,
                top_n=top_n
            )
        )
    )
def get_concept_given_str(self, concept_str):
    """ Retrieve the exact matched concept given a string from ASER

    :param concept_str: a string representation of a concept
    :type concept_str: str
    :return: the exact matched concept
    :rtype: aser.concept.ASERConcept
    """

    # the cid is derived from the string, then looked up like any other cid
    return self.get_exact_match_concept(ASERConcept.generate_cid(concept_str))
def get_concepts_given_strs(self, concept_strs):
    """ Retrieve the exact matched concepts given strings from ASER

    :param concept_strs: string representations of concepts
    :type concept_strs: List[str]
    :return: the exact matched concepts
    :rtype: List[aser.concept.ASERConcept]
    """

    cids = [ASERConcept.generate_cid(concept_str) for concept_str in concept_strs]
    return self.get_exact_match_concepts(cids)
"""
KG (Relations)
"""
def _convert_relation_to_row(self, relation):
    """ Serialize a relation into an ordered database row

    :param relation: the relation to serialize
    :type relation: aser.relation.Relation
    :return: a row with "_id", the plain columns, and one value per relation sense
    :rtype: OrderedDict
    """

    row = OrderedDict()
    row["_id"] = relation.rid
    # columns between "_id" and the trailing sense columns come from attributes
    plain_columns = self.relation_columns[1:-len(relation_senses)]
    row.update((column, getattr(relation, column)) for column in plain_columns)
    # every sense gets a value; senses missing on the relation are stored as 0.0
    for sense in relation_senses:
        row[sense] = relation.relations.get(sense, 0.0)
    return row
def _convert_row_to_relation(self, row):
    """ Rebuild a relation from a database row

    :param row: a row that provides "hid", "tid" and the sense columns
    :type row: Dict[str, object]
    :return: a relation holding every float-valued column greater than 0.0
    :rtype: aser.relation.Relation
    """

    head_id = row["hid"]
    tail_id = row["tid"]
    positive_counts = {}
    for column, value in row.items():
        if isinstance(value, float) and value > 0.0:
            positive_counts[column] = value
    return Relation(head_id, tail_id, positive_counts)
def get_relation_columns(self, columns):
    """ Get column information from relations

    :param columns: the columns to retrieve
    :type columns: List[str]
    :return: a list of retrieved rows
    :rtype: List[Dict[str, object]]
    """

    table = self.relation_table_name
    return self._conn.get_columns(table, columns)
def _insert_relation(self, relation):
    """ Insert a new relation into the database and register it in the caches

    "insert" mode only records the rid. "cache" mode also caches the relation
    and extends partial-cache lists that already exist. "memory" mode caches
    the relation and creates or extends the partial-cache lists. Any other
    mode leaves the in-memory state untouched.

    :param relation: the relation to insert
    :type relation: aser.relation.Relation
    :return: the inserted relation (the same object)
    :rtype: aser.relation.Relation
    """

    row = self._convert_relation_to_row(relation)
    self._conn.insert_row(self.relation_table_name, row)

    mode = self.mode
    if mode not in ("insert", "cache", "memory"):
        return relation
    self.rids.add(relation.rid)
    if mode != "insert":
        self.rid2relation_cache[relation.rid] = relation
        for attr, index in self.partial2rids_cache.items():
            key = getattr(relation, attr)
            if key in index:
                index[key].append(relation.rid)
            elif mode == "memory":
                # only the fully loaded mode may start a new list
                index[key] = [relation.rid]
    return relation
def _insert_relations(self, relations):
    """ Insert several new relations and register them in the caches

    "insert" mode only records the rids. "cache" mode also caches each
    relation and extends partial-cache lists that already exist. "memory" mode
    caches each relation and creates or extends the partial-cache lists.

    :param relations: the relations to insert
    :type relations: List[aser.relation.Relation]
    :return: the given relations
    :rtype: List[aser.relation.Relation]
    """

    rows = [self._convert_relation_to_row(item) for item in relations]
    self._conn.insert_rows(self.relation_table_name, rows)

    mode = self.mode
    if mode not in ("insert", "cache", "memory"):
        return relations
    for item in relations:
        self.rids.add(item.rid)
        if mode == "insert":
            continue
        self.rid2relation_cache[item.rid] = item
        for attr, index in self.partial2rids_cache.items():
            key = getattr(item, attr)
            if key in index:
                index[key].append(item.rid)
            elif mode == "memory":
                # only the fully loaded mode may start a new list
                index[key] = [item.rid]
    return relations
def _get_relation_and_store_in_cache(self, rid):
    """ Load one relation through the batch loader and return it

    :param rid: the id of the relation to load
    :type rid: str
    :return: the first element returned by `_get_relations_and_store_in_cache`
    :rtype: aser.relation.Relation
    """

    loaded = self._get_relations_and_store_in_cache([rid])
    return loaded[0]
def _get_relations_and_store_in_cache(self, rids):
    """ Load relations from the database and remember them in the cache

    :param rids: the ids of the relations to load
    :type rids: List[str]
    :return: the converted relations, in the order returned by the database
    :rtype: List[aser.relation.Relation]
    """

    fetched_rows = self._conn.select_rows(self.relation_table_name, rids, self.relation_columns)
    relations = [self._convert_row_to_relation(row) for row in fetched_rows]
    # only truthy results are cached
    for relation in relations:
        if not relation:
            continue
        self.rid2relation_cache[relation.rid] = relation
    return relations
def _update_relation(self, relation):
    """ Add the sense frequencies of `relation` to the stored relation

    Only senses with a frequency greater than 0.0 on `relation` are updated.
    The database is updated first. A cached copy then receives the same
    increments; without a cached copy the relation is reloaded from the
    database and cached.

    :param relation: a relation that already exists, carrying new frequencies
    :type relation: aser.relation.Relation
    :return: the updated relation
    :rtype: aser.relation.Relation
    """

    # find the senses that carry new frequencies
    update_columns = [r for r in relation_senses if relation.relations.get(r, 0.0) > 0.0]
    # update db
    update_op = self._conn.get_update_op(update_columns, "+")
    row = self._convert_relation_to_row(relation)
    self._conn.update_row(self.relation_table_name, row, update_op, update_columns)
    # update cache
    updated_relation = self.rid2relation_cache.get(relation.rid, None)
    if updated_relation:
        for r in update_columns:
            # the cached relation may not have this sense yet, so start from 0.0
            updated_relation.relations[r] = updated_relation.relations.get(r, 0.0) + relation.relations[r]
    else:
        updated_relation = self._get_relation_and_store_in_cache(relation.rid)
    return updated_relation
def _update_relations(self, relations):
    """ Add the sense frequencies of several relations to their stored versions

    All sense columns are updated in the database first. Cached copies then
    receive every increment greater than 0.0; relations without a cached copy
    are reloaded from the database and cached. A relation whose rid is unknown
    yields None in the result.

    :param relations: relations carrying new frequencies
    :type relations: List[aser.relation.Relation]
    :return: the updated relations, aligned with the input
    :rtype: List[Union[aser.relation.Relation, None]]
    """

    # update db
    update_op = self._conn.get_update_op(relation_senses, "+")
    rows = list(map(self._convert_relation_to_row, relations))
    self._conn.update_rows(self.relation_table_name, rows, update_op, relation_senses)
    # update cache
    updated_relations = []
    missed_indices = []
    missed_rids = []
    for idx, relation in enumerate(relations):
        if relation.rid not in self.rids:
            # unknown rid: exactly one None so that positions stay aligned
            updated_relations.append(None)
            continue
        updated_relation = self.rid2relation_cache.get(relation.rid, None)
        updated_relations.append(updated_relation)
        if updated_relation:
            for r in relation_senses:
                increment = relation.relations.get(r, 0.0)
                if increment > 0.0:
                    # the cached relation may not have this sense yet
                    updated_relation.relations[r] = updated_relation.relations.get(r, 0.0) + increment
        else:
            missed_indices.append(idx)
            missed_rids.append(relation.rid)
    if missed_rids:
        for idx, updated_relation in enumerate(self._get_relations_and_store_in_cache(missed_rids)):
            updated_relations[missed_indices[idx]] = updated_relation
    return updated_relations
def insert_relation(self, relation):
    """ Insert/Update a relation into ASER
    (suggestion: consider to use `insert_relations` if you want to insert multiple relations)

    :param relation: a relation to insert/update
    :type relation: aser.relation.Relation
    :return: the inserted/updated relation
    :rtype: aser.relation.Relation
    """

    # existence is decided by the set of known rids (as in `insert_relations`),
    # not by the relation cache, which may lack relations that are stored
    if relation.rid not in self.rids:
        return self._insert_relation(relation)
    return self._update_relation(relation)
def insert_relations(self, relations):
    """ Insert/Update relations into ASER

    :param relations: relations to insert/update
    :type relations: List[aser.relation.Relation]
    :return: the inserted/updated relations
    :rtype: List[aser.relation.Relation]
    """

    results = []
    fresh = []
    known = []
    known_positions = []
    for position, relation in enumerate(relations):
        if relation.rid in self.rids:
            # placeholder, replaced below by the updated relation
            known_positions.append(position)
            known.append(relation)
            results.append(None)
        else:
            fresh.append(relation)
            results.append(relation)
    if fresh:
        self._insert_relations(fresh)
    if known:
        for position, updated in zip(known_positions, self._update_relations(known)):
            results[position] = updated
    return results
def get_exact_match_relation(self, relation):
    """ Retrieve an exact matched relation from ASER
    (suggestion: consider to use `get_exact_match_relations` if you want to retrieve multiple relations)

    :param relation: a relation that contains the rid or a concept pair that contains two cids
    :type relation: Union[aser.relation.Relation, Dict[str, object], str, Tuple[aser.concept.ASERConcept, aser.concept.ASERConcept], Tuple[str, str]]
    :return: the exact matched relation, or None if the rid is unknown
    :rtype: aser.relation.Relation
    :raises ValueError: if `relation` is none of the supported representations
    """

    # resolve the rid from whichever representation was given
    if isinstance(relation, Relation):
        rid = relation.rid
    elif isinstance(relation, dict):
        rid = relation["rid"]
    elif isinstance(relation, str):
        rid = relation
    elif isinstance(relation, (tuple, list)) and len(relation) == 2:
        head, tail = relation[0], relation[1]
        if isinstance(head, ASERConcept) and isinstance(tail, ASERConcept):
            rid = Relation.generate_rid(head.cid, tail.cid)
        elif isinstance(head, str) and isinstance(tail, str):
            rid = Relation.generate_rid(head, tail)
        else:
            raise ValueError(
                "Error: relation should be (an instance of ASERConcept, an instance of ASERConcept) or (hid, tid)."
            )
    else:
        raise ValueError(
            "Error: relation should be an instance of Relation, a dictionary, rid,"
            "(an instance of ASERConcept, an instance of ASERConcept), or (hid, tid)."
        )

    if rid not in self.rids:
        return None
    cached = self.rid2relation_cache.get(rid, None)
    if cached:
        return cached
    return self._get_relation_and_store_in_cache(rid)
def get_exact_match_relations(self, relations):
    """ Retrieve exact matched relations from ASER

    The type of the first element decides how all elements are read. The
    result has one entry per input: None for an unknown rid, otherwise the
    relation from the cache or, on a miss, loaded from the database.

    :param relations: relations that contain the rids or concept pairs each of which contains two cids
    :type relations: Union[List[aser.relation.Relation], List[Dict[str, object]], List[str], List[Tuple[aser.concept.ASERConcept, aser.concept.ASERConcept]], List[Tuple[str, str]]]
    :return: the exact matched relations
    :rtype: List[aser.relation.Relation]
    :raises ValueError: if the first element is none of the supported representations
    """

    exact_match_relations = []
    if len(relations):
        if isinstance(relations[0], Relation):
            rids = [relation.rid for relation in relations]
        elif isinstance(relations[0], dict):
            rids = [relation["rid"] for relation in relations]
        elif isinstance(relations[0], str):
            rids = relations
        elif isinstance(relations[0], (tuple, list)) and len(relations[0]) == 2:
            if isinstance(relations[0][0], ASERConcept) and isinstance(relations[0][1], ASERConcept):
                rids = [Relation.generate_rid(relation[0].cid, relation[1].cid) for relation in relations]
            elif isinstance(relations[0][0], str) and isinstance(relations[0][1], str):
                rids = [Relation.generate_rid(relation[0], relation[1]) for relation in relations]
            else:
                raise ValueError(
                    "Error: relations should be [(an instance of ASERConcept, an instance of ASERConcept), ...] or [(hid, tid), ...]."
                )
        else:
            raise ValueError(
                "Error: relations should be instances of Relation, dictionaries, rids, [(an instance of ASERConcept, an instance of ASERConcept), ...], or [(hid, tid), ...]."
            )

        missed_indices = []
        missed_rids = []
        for idx, rid in enumerate(rids):
            if rid not in self.rids:
                # unknown rid: exactly one None so that positions stay aligned
                exact_match_relations.append(None)
                continue
            exact_match_relation = self.rid2relation_cache.get(rid, None)
            exact_match_relations.append(exact_match_relation)
            if not exact_match_relation:
                missed_indices.append(idx)
                missed_rids.append(rid)
        if missed_rids:
            for idx, exact_match_relation in enumerate(self._get_relations_and_store_in_cache(missed_rids)):
                exact_match_relations[missed_indices[idx]] = exact_match_relation
    return exact_match_relations
def get_relations_by_keys(self, bys, keys, order_bys=None, reverse=False, top_n=None):
    """ Retrieve multiple partial matched relations by keys and values from ASER

    Pairs whose column is not one of the relation columns are ignored; the
    given `bys` and `keys` lists are not modified. When a partial cache exists
    for "hid" or "tid" and that column is queried, the cache is consulted
    first and, outside "memory" mode, filled from the database on a miss.
    Without a usable partial cache the query goes straight to the database.

    :param bys: the given columns to match
    :type bys: List[str]
    :param keys: the given values to match
    :type keys: List[str]
    :param order_bys: the columns whose value are used to sort rows
    :type order_bys: Union[List[str], None] (default = None)
    :param reverse: whether to sort in a reversed order
    :type reverse: bool (default = False)
    :param top_n: how many relations to return, default `None` for all relations
    :type top_n: Union[int, None] (default = None)
    :return: the partial matched relations
    :rtype: List[aser.relation.Relation]
    """

    assert len(bys) == len(keys)
    # filter into fresh lists instead of popping from the caller's arguments
    valid_pairs = [(by, key) for by, key in zip(bys, keys) if by in self.relation_columns]
    if not valid_pairs:
        return []
    bys = [by for by, _ in valid_pairs]
    keys = [key for _, key in valid_pairs]

    cache = None
    by_index = -1
    for k in ["hid", "tid"]:
        if k in bys and k in self.partial2rids_cache:
            cache = self.partial2rids_cache[k]
            by_index = bys.index(k)
            break

    # `is not None` rather than truthiness: an empty partial cache still has
    # to be consulted, otherwise it could never receive its first entry
    if cache is not None:
        if keys[by_index] in cache:
            key_match_relations = [self.rid2relation_cache[rid] for rid in cache[keys[by_index]]]
        else:
            if self.mode == "memory":
                # everything is loaded in this mode, so a miss means no match
                return []
            key_cache = []
            key_match_relations = list(
                map(
                    self._convert_row_to_relation,
                    self._conn.get_rows_by_keys(
                        self.relation_table_name, [bys[by_index]], [keys[by_index]], self.relation_columns
                    )
                )
            )
            for key_match_relation in key_match_relations:
                if key_match_relation.rid not in self.rid2relation_cache:
                    self.rid2relation_cache[key_match_relation.rid] = key_match_relation
                key_cache.append(key_match_relation.rid)
            cache[keys[by_index]] = key_cache
        # apply the remaining conditions in memory
        # NOTE(review): this and the sort below subscript relation objects by
        # column name -- confirm that Relation supports item access
        for i in range(len(bys)):
            if i == by_index:
                continue
            key_match_relations = list(filter(lambda x: x[bys[i]] == keys[i], key_match_relations))
        if order_bys:
            key_match_relations.sort(key=operator.itemgetter(*order_bys), reverse=reverse)
        if top_n:
            key_match_relations = key_match_relations[:top_n]
        return key_match_relations

    return list(
        map(
            self._convert_row_to_relation,
            self._conn.get_rows_by_keys(
                self.relation_table_name,
                bys,
                keys,
                self.relation_columns,
                order_bys=order_bys,
                reverse=reverse,
                top_n=top_n
            )
        )
    )
"""
KG (ConceptInstancePairs)
"""
def _convert_concept_instance_pair_to_row(self, concept_instance_pair):
    """ Serialize a concept-instance pair into an ordered database row

    :param concept_instance_pair: a pair object, or a triple whose first
        element provides `cid`, whose second element provides `eid` and
        `pattern`, and whose third element is the score
    :type concept_instance_pair: Union[aser.concept.ASERConceptInstancePair, Tuple[object, object, float]]
    :return: a row with the keys "_id", "cid", "eid", "pattern" and "score"
    :rtype: OrderedDict
    :raises ValueError: if the argument has neither supported shape
    """

    if isinstance(concept_instance_pair, ASERConceptInstancePair):
        row = OrderedDict(
            {
                "_id": concept_instance_pair.pid,
                "cid": concept_instance_pair.cid,
                "eid": concept_instance_pair.eid,
                "pattern": concept_instance_pair.pattern,
                "score": concept_instance_pair.score
            }
        )
    elif isinstance(concept_instance_pair, (list, tuple)) and len(concept_instance_pair) == 3:
        pid = ASERConceptInstancePair.generate_pid(concept_instance_pair[0].cid, concept_instance_pair[1].eid)
        row = OrderedDict(
            {
                "_id": pid,
                "cid": concept_instance_pair[0].cid,
                "eid": concept_instance_pair[1].eid,
                "pattern": concept_instance_pair[1].pattern,
                "score": concept_instance_pair[2]
            }
        )
    else:
        # fail with a clear message instead of an unbound local variable
        raise ValueError(
            "Error: concept_instance_pair should be an instance of ASERConceptInstancePair "
            "or a triple of (concept, instance, score)."
        )
    return row
def _convert_row_to_concept_instance_pair(self, row):
    """ Rebuild a concept-instance pair from a database row

    :param row: a row that provides "cid", "eid", "pattern" and "score"
    :type row: Dict[str, object]
    :return: the pair built from those four fields
    :rtype: aser.concept.ASERConceptInstancePair
    """

    cid = row["cid"]
    eid = row["eid"]
    pattern = row["pattern"]
    score = row["score"]
    return ASERConceptInstancePair(cid, eid, pattern, score)
def get_concept_instance_pair_columns(self, columns):
    """ Get column information from concept-instance pairs

    :param columns: the columns to retrieve
    :type columns: List[str]
    :return: a list of retrieved rows
    :rtype: List[Dict[str, object]]
    """

    # the first argument is the table name, not the list of table columns
    return self._conn.get_columns(self.concept_instance_pair_table_name, columns)
def _insert_concept_instance_pair(self, concept_instance_pair):
    """ Insert a new concept-instance pair and register it in the caches

    "insert" mode only records the eid. "cache" mode also extends the score
    lists that already exist. "memory" mode creates or extends the score
    lists. Any other mode leaves the in-memory state untouched. The values
    are read from the converted row, so every input accepted by
    `_convert_concept_instance_pair_to_row` is handled.

    :param concept_instance_pair: the pair to insert
    :type concept_instance_pair: Union[aser.concept.ASERConceptInstancePair, Tuple[object, object, float]]
    :return: the pair rebuilt from the inserted row
    :rtype: aser.concept.ASERConceptInstancePair
    """

    row = self._convert_concept_instance_pair_to_row(concept_instance_pair)
    self._conn.insert_row(self.concept_instance_pair_table_name, row)
    cid, eid, pattern, score = row["cid"], row["eid"], row["pattern"], row["score"]
    if self.mode == "insert":
        self.eids.add(eid)
    elif self.mode == "cache":
        self.eids.add(eid)
        # only extend lists that are already cached
        if cid in self.cid2eid_pattern_scores:
            self.cid2eid_pattern_scores[cid].append((eid, pattern, score))
        if eid in self.eid2cid_scores:
            self.eid2cid_scores[eid].append((cid, score))
    elif self.mode == "memory":
        self.eids.add(eid)
        self.cid2eid_pattern_scores.setdefault(cid, []).append((eid, pattern, score))
        self.eid2cid_scores.setdefault(eid, []).append((cid, score))
    return self._convert_row_to_concept_instance_pair(row)
def _insert_concept_instance_pairs(self, concept_instance_pairs):
    """ Insert several new concept-instance pairs and register them in the caches

    "insert" mode only records the eids. "cache" mode also extends the score
    lists that already exist. "memory" mode creates or extends the score
    lists. Any other mode leaves the in-memory state untouched.

    :param concept_instance_pairs: the pairs to insert
    :type concept_instance_pairs: List[aser.concept.ASERConceptInstancePair]
    :return: the pairs rebuilt from the inserted rows
    :rtype: List[aser.concept.ASERConceptInstancePair]
    """

    rows = [self._convert_concept_instance_pair_to_row(pair) for pair in concept_instance_pairs]
    self._conn.insert_rows(self.concept_instance_pair_table_name, rows)

    mode = self.mode
    if mode in ("insert", "cache", "memory"):
        for pair in concept_instance_pairs:
            self.eids.add(pair.eid)
            if mode == "insert":
                continue
            by_cid_entry = (pair.eid, pair.pattern, pair.score)
            by_eid_entry = (pair.cid, pair.score)
            # existing lists are always extended; only "memory" starts new ones
            if pair.cid in self.cid2eid_pattern_scores:
                self.cid2eid_pattern_scores[pair.cid].append(by_cid_entry)
            elif mode == "memory":
                self.cid2eid_pattern_scores[pair.cid] = [by_cid_entry]
            if pair.eid in self.eid2cid_scores:
                self.eid2cid_scores[pair.eid].append(by_eid_entry)
            elif mode == "memory":
                self.eid2cid_scores[pair.eid] = [by_eid_entry]
    return [self._convert_row_to_concept_instance_pair(row) for row in rows]
def _update_concept_instance_pair(self, concept_instance_pair):
    """ Add the score of `concept_instance_pair` to the stored pair

    The database row is updated first. Outside "insert" mode the matching
    entries in both score caches are increased as well; when neither cache
    holds the pair, the resulting score is read back from the database.

    :param concept_instance_pair: a pair that already exists, carrying a score increment
    :type concept_instance_pair: aser.concept.ASERConceptInstancePair
    :return: a pair carrying the updated score, or None in "insert" mode
    :rtype: Union[aser.concept.ASERConceptInstancePair, None]
    """

    pair = concept_instance_pair
    # update db
    score_op = self._conn.get_update_op(["score"], "+")
    row = self._convert_concept_instance_pair_to_row(pair)
    self._conn.update_row(self.concept_instance_pair_table_name, row, score_op, ["score"])
    if self.mode == "insert":
        return None  # the caller does not use the result in this mode

    # update cache
    new_score = None
    by_eid = self.eid2cid_scores.get(pair.eid, None)
    if by_eid:
        for pos, entry in enumerate(by_eid):
            if pair.cid == entry[0]:
                new_score = entry[1] + pair.score
                by_eid[pos] = (entry[0], new_score)
                break
    by_cid = self.cid2eid_pattern_scores.get(pair.cid, None)
    if by_cid:
        for pos, entry in enumerate(by_cid):
            if pair.eid == entry[0]:
                new_score = entry[2] + pair.score
                by_cid[pos] = (entry[0], entry[1], new_score)
                break
    if new_score is None:
        # neither cache knew the pair: read the stored score back
        stored = self._conn.select_row(self.concept_instance_pair_table_name, row["_id"], ["score"])
        new_score = stored["score"]
    return ASERConceptInstancePair(pair.cid, pair.eid, pair.pattern, new_score)
def _update_concept_instance_pairs(self, concept_instance_pairs):
    """ Apply score updates for several concept-instance pairs to the database and the caches

    :param concept_instance_pairs: the pairs whose ``score`` is combined with the stored score
    :type concept_instance_pairs: List[aser.concept.ASERConceptInstancePair]
    :return: pairs carrying the updated scores, in input order; a list of None of the same length when the connection is in "insert" mode
    :rtype: List[Union[aser.concept.ASERConceptInstancePair, None]]
    """
    # update db
    update_op = self._conn.get_update_op(["score"], "+")
    rows = list(map(self._convert_concept_instance_pair_to_row, concept_instance_pairs))
    self._conn.update_rows(self.concept_instance_pair_table_name, rows, update_op, ["score"])

    # update cache
    if self.mode == "insert":
        return [None] * len(concept_instance_pairs)  # callers ignore the result in insert mode

    updated_scores = []
    missed_indices = []
    missed_pids = []
    for pair_idx, concept_instance_pair in enumerate(concept_instance_pairs):
        # Reset for every pair so a cache miss is never mistaken for the previous pair's score.
        updated_score = None

        cached_cid_scores = self.eid2cid_scores.get(concept_instance_pair.eid, None)
        if cached_cid_scores:
            # Use a separate position variable: the outer pair index must stay intact.
            for pos, cid_score in enumerate(cached_cid_scores):
                if concept_instance_pair.cid == cid_score[0]:
                    updated_score = cid_score[1] + concept_instance_pair.score
                    cached_cid_scores[pos] = (cid_score[0], updated_score)
                    break

        cached_eid_pattern_scores = self.cid2eid_pattern_scores.get(concept_instance_pair.cid, None)
        if cached_eid_pattern_scores:
            for pos, eid_pattern_score in enumerate(cached_eid_pattern_scores):
                if concept_instance_pair.eid == eid_pattern_score[0]:
                    updated_score = eid_pattern_score[2] + concept_instance_pair.score
                    cached_eid_pattern_scores[pos] = (eid_pattern_score[0], eid_pattern_score[1], updated_score)
                    break

        if updated_score is None:
            # Not cached anywhere: remember where the score belongs and fetch it in one batch below.
            missed_indices.append(pair_idx)
            missed_pids.append(concept_instance_pair.pid)
        updated_scores.append(updated_score)

    if missed_indices:
        # NOTE(review): assumes select_rows yields rows in the order of missed_pids - confirm.
        fetched_rows = self._conn.select_rows(self.concept_instance_pair_table_name, missed_pids, ["score"])
        for pair_idx, updated_row in zip(missed_indices, fetched_rows):
            updated_scores[pair_idx] = updated_row["score"]

    return [
        ASERConceptInstancePair(
            concept_instance_pair.cid, concept_instance_pair.eid, concept_instance_pair.pattern, updated_score
        ) for concept_instance_pair, updated_score in zip(concept_instance_pairs, updated_scores)
    ]
def insert_concept_instance_pair(self, concept_instance_pair):
    """ Insert a concept-instance pair into ASER, or update it when both ends are already known

    (suggestion: prefer `insert_concept_instance_pairs` when inserting multiple pairs)

    :param concept_instance_pair: a concept-instance pair, or a (concept, eventuality, score) triple
    :type concept_instance_pair: Union[aser.concept.ASERConceptInstancePair, Tuple[aser.concept.ASERConcept, aser.eventuality.Eventuality, float]]
    :return: the result of the underlying insert or update call
    :rtype: aser.concept.ASERConceptInstancePair
    """
    if isinstance(concept_instance_pair, ASERConceptInstancePair):
        pair = concept_instance_pair
    else:
        # Triple form: the pair takes cid from the concept and eid/pattern from the eventuality.
        pair = ASERConceptInstancePair(
            concept_instance_pair[0].cid,
            concept_instance_pair[1].eid,
            concept_instance_pair[1].pattern,
            concept_instance_pair[2]
        )

    # The update path is taken only when both the cid and the eid are already registered.
    if pair.cid not in self.cids or pair.eid not in self.eids:
        return self._insert_concept_instance_pair(pair)
    return self._update_concept_instance_pair(pair)
def insert_concept_instance_pairs(self, concept_instance_pairs):
    """ Insert/Update concept_instance_pairs into ASER

    :param concept_instance_pairs: concept-instance pairs, or (concept, eventuality, score) triples, to insert/update
    :type concept_instance_pairs: Union[List[aser.concept.ASERConceptInstancePair], List[Tuple[aser.concept.ASERConcept, aser.eventuality.Eventuality, float]]]
    :return: the inserted/updated concept-instance pairs, one per input item and in input order
    :rtype: List[aser.concept.ASERConceptInstancePair]
    """
    # One slot per input item; every slot is filled by index so the output lines up with the input.
    results = [None] * len(concept_instance_pairs)
    new_concept_instance_pairs = []
    existing_indices = []
    existing_concept_instance_pairs = []
    for idx, concept_instance_pair in enumerate(concept_instance_pairs):
        if not isinstance(concept_instance_pair, ASERConceptInstancePair):
            # Triple form: cid comes from the concept, eid/pattern from the eventuality.
            concept_instance_pair = ASERConceptInstancePair(
                concept_instance_pair[0].cid,
                concept_instance_pair[1].eid,
                concept_instance_pair[1].pattern,
                concept_instance_pair[2]
            )
        if concept_instance_pair.cid in self.cids and concept_instance_pair.eid in self.eids:
            # Slot stays None for now; it is filled once the batched update returns.
            existing_indices.append(idx)
            existing_concept_instance_pairs.append(concept_instance_pair)
        else:
            new_concept_instance_pairs.append(concept_instance_pair)
            results[idx] = concept_instance_pair
    if new_concept_instance_pairs:
        self._insert_concept_instance_pairs(new_concept_instance_pairs)
    if existing_indices:
        updated_pairs = self._update_concept_instance_pairs(existing_concept_instance_pairs)
        for idx, updated_pair in zip(existing_indices, updated_pairs):
            results[idx] = updated_pair
    return results
def get_eventualities_given_concept(self, concept):
    """ Look up the eventuality links recorded for a concept

    :param concept: the concept, a dictionary holding its "cid", or the cid itself
    :type concept: Union[aser.concept.ASERConcept, Dict[str, object], str]
    :return: an empty list in "insert" mode; the cached list for the cid when it is non-empty (this class stores (eid, pattern, score) tuples there); otherwise the rows the connection returns for the columns "eid", "pattern", and "score"
    :rtype: List[object]
    :raises ValueError: if `concept` is none of the accepted types
    """
    if self.mode == "insert":
        return []

    # Reject unsupported inputs up front, then resolve the cid in the same priority order.
    if not isinstance(concept, (ASERConcept, dict, str)):
        raise ValueError("Error: conceptualize should be an instance of ASERConcept, a dictionary, or a cid.")
    if isinstance(concept, ASERConcept):
        cid = concept.cid
    elif isinstance(concept, dict):
        cid = concept["cid"]
    else:
        cid = concept

    # A non-empty cache entry is handed back as-is (not copied).
    cached_links = self.cid2eid_pattern_scores.get(cid, None)
    if cached_links:
        return cached_links

    # Cache miss (or empty entry): query the pair table by cid.
    return self._conn.get_rows_by_keys(
        self.concept_instance_pair_table_name, bys=["cid"], keys=[cid], columns=["eid", "pattern", "score"]
    )
def get_concepts_given_eventuality(self, eventuality):
    """ Look up the concepts linked to an eventuality, together with their scores

    :param eventuality: the eventuality, a dictionary holding its "eid", or the eid itself
    :type eventuality: Union[aser.eventuality.Eventuality, Dict[str, object], str]
    :return: an empty list in "insert" mode; otherwise (concept, score) pairs, where the concepts come from `get_exact_match_concepts`
    :rtype: List[Tuple[object, float]]
    :raises ValueError: if `eventuality` is none of the accepted types
    """
    if self.mode == "insert":
        return []

    # Reject unsupported inputs up front, then resolve the eid in the same priority order.
    if not isinstance(eventuality, (Eventuality, dict, str)):
        raise ValueError("Error: conceptualize should be an instance of Eventuality, a dictionary, or a eid.")
    if isinstance(eventuality, Eventuality):
        eid = eventuality.eid
    elif isinstance(eventuality, dict):
        eid = eventuality["eid"]
    else:
        eid = eventuality

    # Cached entries are indexed by position (cid at 0, score at 1); database rows by column name.
    links = self.eid2cid_scores.get(eid, None)
    if links:
        cid_key, score_key = 0, 1
    else:
        cid_key, score_key = "cid", "score"
        links = self._conn.get_rows_by_keys(
            self.concept_instance_pair_table_name, bys=["eid"], keys=[eid], columns=["cid", "score"]
        )
    cids = [link[cid_key] for link in links]
    scores = [link[score_key] for link in links]

    concepts = self.get_exact_match_concepts(cids)
    return list(zip(concepts, scores))
"""
Additional APIs
"""