Python API
Subpackages
srtk.link_wikidata module
Entity linking. This step links the entity mentions in the question to entities in the Wikidata knowledge graph. It runs inference on the REL endpoint.
- srtk.link.link(args)
Link the entities in the questions to the Wikidata knowledge graph.
srtk.preprocess module
This script creates the training data from the grounded questions.
The input should be a JSONL file in which each line is one grounded question. Each line should follow the format of this example (shown pretty-printed here; in the file, each record occupies a single line):
{
  "id": "sample-id",
  "question": "Which universities did Barack Obama graduate from?",
  "question_entities": ["Q76"],
  "answer_entities": ["Q49122", "Q1346110", "Q4569677"]
}
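A minimal sketch of reading and checking such a file with only the standard library. The function name `read_grounded_questions` and the choice to require all four fields are assumptions for illustration; this is not srtk's own loader.

```python
import json
import os
import tempfile
from typing import Any, Dict, Iterator

# Assumption: every grounded question carries the four fields of the example above.
REQUIRED_FIELDS = ("id", "question", "question_entities", "answer_entities")


def read_grounded_questions(path: str) -> Iterator[Dict[str, Any]]:
    """Yield one grounded question per non-empty line of a JSONL file."""
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # skip blank lines instead of failing on them
            sample = json.loads(line)
            missing = [key for key in REQUIRED_FIELDS if key not in sample]
            if missing:
                raise ValueError(f"line {line_no}: missing fields {missing}")
            yield sample


# Usage: write the example record to a temporary JSONL file and read it back.
example = {
    "id": "sample-id",
    "question": "Which universities did Barack Obama graduate from?",
    "question_entities": ["Q76"],
    "answer_entities": ["Q49122", "Q1346110", "Q4569677"],
}
with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False, encoding="utf-8") as tmp:
    tmp.write(json.dumps(example) + "\n")
loaded = list(read_grounded_questions(tmp.name))
os.remove(tmp.name)
```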
srtk.retrieve module
This script retrieves subgraphs from a knowledge graph according to a natural language query (usually a question). This command can also be used to evaluate a trained retriever when the answer entities are known.
The expected fields of one sample are:
- question: question text
- question_entities: list of grounded question entities (IDs)
For evaluation, the following field is also required:
- answer_entities: list of grounded answer entities (IDs)
- class srtk.retrieve.KnowledgeGraphTraverser(kg: KnowledgeGraphBase)
KnowledgeGraphTraverser is a helper class that traverses a knowledge graph.
- deduce_leaf_relations(entity, path)
Deduce leaf relations from an entity following a path hop by hop.
- Parameters:
entity (str) – the identifier of the source node
path (list[str]) – a list of relation identifiers
- Returns:
a tuple of relations that are n hops away from the source node, where n is the length of the path
- Return type:
tuple[str]
- deduce_leaves(entity, path)
Deduce leaves from an entity following a path hop by hop.
- Parameters:
entity (str) – the identifier of the source node
path (list[str]) – a list of relation identifiers
- Returns:
a set of leaf identifiers that are n hops away from the source node, where n is the length of the path
- Return type:
set[str]
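The hop-by-hop traversal behind deduce_leaves and deduce_leaf_relations can be sketched on a toy in-memory graph. The `ToyKG` dict mapping `(head, relation)` to a set of tails is a hypothetical stand-in for `KnowledgeGraphBase`, and reading "relations n hops away" as the relations leaving the nodes reached after the path is an assumption.

```python
from typing import Dict, List, Set, Tuple

# Hypothetical toy graph: (head, relation) -> set of tail entities.
# This stands in for srtk's KnowledgeGraphBase and is not its API.
ToyKG = Dict[Tuple[str, str], Set[str]]


def deduce_leaves(kg: ToyKG, entity: str, path: List[str]) -> Set[str]:
    """Follow `path` one relation at a time; return the nodes reached last."""
    frontier = {entity}
    for relation in path:
        frontier = {tail for node in frontier for tail in kg.get((node, relation), set())}
    return frontier


def deduce_leaf_relations(kg: ToyKG, entity: str, path: List[str]) -> Tuple[str, ...]:
    """Return the relations leaving the leaves reached by following `path`."""
    leaves = deduce_leaves(kg, entity, path)
    return tuple(sorted({rel for (head, rel) in kg if head in leaves}))


kg: ToyKG = {
    ("Q76", "P69"): {"Q49122", "Q1346110", "Q4569677"},  # P69: "educated at"
    ("Q49122", "P31"): {"Q3918"},  # P31: "instance of", Q3918: "university"
}
leaves = deduce_leaves(kg, "Q76", ["P69"])
next_relations = deduce_leaf_relations(kg, "Q76", ["P69"])
```

With an empty path, the leaves are just the source entity, so deduce_leaf_relations returns the relations leaving it.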
- get_relation_label(identifier)
Get the label of an entity or a relation.
It serves as a proxy to the knowledge graph’s get_label function. For Freebase, the identifier is used directly as the label. For other knowledge graphs, the retrieved label is returned if it exists; otherwise the identifier is returned.
- Parameters:
identifier (str) – the identifier of an entity or a relation
- Returns:
the label of the entity or the relation
- Return type:
str
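The fallback rule can be sketched as follows. The `kg_name` string and the `labels` dict are hypothetical stand-ins for the traverser's knowledge graph and its get_label function.

```python
from typing import Dict


def get_relation_label(identifier: str, kg_name: str, labels: Dict[str, str]) -> str:
    """Label lookup with the fallback described above.

    `kg_name` and `labels` are hypothetical stand-ins for the traverser's
    knowledge graph and its get_label function.
    """
    if kg_name == "freebase":
        # Freebase: use the identifier itself as the label.
        return identifier
    label = labels.get(identifier)
    # Fall back to the identifier when no label was retrieved.
    return label if label else identifier


wikidata_labels = {"P69": "educated at"}
```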
- retrive_subgraph(entity, path)
Retrieve subgraph entities and triplets by traversing from an entity along a relation path, hop by hop.
- Parameters:
entity (str) – the identifier of the source node
path (list[str]) – a list of relation identifiers
- Returns:
entities – a list of entity identifiers; triplets – a list of triplets
- Return type:
tuple(list, list)
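A sketch of collecting entities and triplets while walking a relation path, using the same kind of hypothetical `(head, relation) -> tails` dict as a stand-in for the real knowledge graph. Returning a sorted entity list is a choice made here for deterministic output.

```python
from typing import Dict, List, Set, Tuple

# Hypothetical toy graph: (head, relation) -> set of tails; not srtk's KG API.
ToyKG = Dict[Tuple[str, str], Set[str]]
Triplet = Tuple[str, str, str]


def retrieve_subgraph(kg: ToyKG, entity: str, path: List[str]) -> Tuple[List[str], List[Triplet]]:
    """Collect entities and triplets visited while following `path` hop by hop."""
    entities = {entity}
    triplets: List[Triplet] = []
    frontier = {entity}
    for relation in path:
        next_frontier: Set[str] = set()
        for head in sorted(frontier):  # sorted only to make the output deterministic
            for tail in sorted(kg.get((head, relation), set())):
                triplets.append((head, relation, tail))
                next_frontier.add(tail)
        entities |= next_frontier
        frontier = next_frontier
    return sorted(entities), triplets


kg: ToyKG = {
    ("Q76", "P69"): {"Q49122", "Q1346110", "Q4569677"},
    ("Q49122", "P31"): {"Q3918"},
    ("Q1346110", "P31"): {"Q3918"},
}
entities, triplets = retrieve_subgraph(kg, "Q76", ["P69", "P31"])
```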
- class srtk.retrieve.Path(prev_relations, score)
- prev_relations
Alias for field number 0
- score
Alias for field number 1
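"Alias for field number N" is the docstring that Python's `collections.namedtuple` generates for each field, which suggests Path is a named tuple. An equivalent definition, as a sketch (storing prev_relations as a tuple is an assumption):

```python
from collections import namedtuple

# Assumed equivalent of srtk.retrieve.Path, inferred from its field docstrings.
Path = namedtuple("Path", ["prev_relations", "score"])

start = Path(prev_relations=(), score=0.0)
# Extending a path creates a new tuple; named tuples are immutable.
extended = Path(start.prev_relations + ("P69",), start.score + 0.8)
```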
- class srtk.retrieve.Retriever(kg: KnowledgeGraphBase, scorer: Scorer, beam_width: int, max_depth: int)
Retriever retrieves subgraphs from a knowledge graph with a question and its linked entities. The retrieval process takes the semantic information of the question and the expanding path into consideration.
- beam_search_path(question: str, question_entity: str)
This function reimplements the solution from RUC’s paper. During the search, only the history paths (sequences of relations) are recorded; each candidate next relation is found by looking up the relations at the end of a history path followed from the question entity.
- Parameters:
question (str) – a natural language question
question_entity (str) – a grounded question entity
- Returns:
path score list, a list of (path, score) tuples
- Return type:
list[Path]
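A minimal beam-search sketch in the spirit of the description above: only relation histories are kept, and each step re-derives the candidate relations by walking the history from the question entity. The toy graph, the fixed `toy_scores` table standing in for the trained scorer, and summing per-hop scores are all assumptions, not srtk's implementation.

```python
from collections import namedtuple
from typing import Callable, Dict, List, Set, Tuple

Path = namedtuple("Path", ["prev_relations", "score"])
# Hypothetical toy graph: (head, relation) -> set of tails; not srtk's KG API.
ToyKG = Dict[Tuple[str, str], Set[str]]
# Hypothetical scorer signature: (question, previous relations, next relation) -> score.
ScoreFn = Callable[[str, Tuple[str, ...], str], float]


def leaf_relations(kg: ToyKG, entity: str, relations: Tuple[str, ...]) -> List[str]:
    """Relations leaving the nodes reached from `entity` along `relations`."""
    frontier = {entity}
    for relation in relations:
        frontier = {t for node in frontier for t in kg.get((node, relation), set())}
    return sorted({rel for (head, rel) in kg if head in frontier})


def beam_search_path(question: str, question_entity: str, kg: ToyKG,
                     score_fn: ScoreFn, beam_width: int, max_depth: int) -> List[Path]:
    beam = [Path((), 0.0)]  # only relation histories are stored, never entities
    finished: List[Path] = []
    for _ in range(max_depth):
        candidates: List[Path] = []
        for path in beam:
            next_relations = leaf_relations(kg, question_entity, path.prev_relations)
            if not next_relations and path.prev_relations:
                finished.append(path)  # dead end: keep the path as it is
            for relation in next_relations:
                score = path.score + score_fn(question, path.prev_relations, relation)
                candidates.append(Path(path.prev_relations + (relation,), score))
        if not candidates:
            beam = []  # every path in the beam was already moved to `finished`
            break
        beam = sorted(candidates, key=lambda p: p.score, reverse=True)[:beam_width]
    finished.extend(beam)
    return sorted(finished, key=lambda p: p.score, reverse=True)


kg: ToyKG = {
    ("Q76", "P69"): {"Q49122", "Q1346110"},
    ("Q76", "P27"): {"Q30"},
    ("Q49122", "P31"): {"Q3918"},
}
# Hypothetical fixed scores in place of a trained scorer model.
toy_scores = {"P69": 1.0, "P31": 0.2, "P27": 0.1}
paths = beam_search_path("Which universities did Barack Obama graduate from?", "Q76", kg,
                         lambda q, prev, rel: toy_scores[rel], beam_width=2, max_depth=2)
```

This sketch only stops a path early when it reaches a dead end; how srtk decides when a path should end is not described here.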
- expand_and_score_paths(question: str, paths: List[Path], relations_batched: List[List[str]]) → List[Path]
Expand the paths by one hop and score them by comparing the embedding similarity between the query (question + prev_relations) and the next relation.
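To make the query/relation comparison concrete, here is a sketch that swaps the trained encoder for a toy bag-of-words vector and uses cosine similarity. The `embed` and `cosine` helpers are hypothetical, relations are given as text labels, and adding the new similarity to the path's previous score is an assumption.

```python
import math
import re
from collections import Counter, namedtuple
from typing import List

Path = namedtuple("Path", ["prev_relations", "score"])


def embed(text: str) -> Counter:
    """Toy bag-of-words 'embedding'; a stand-in for the trained encoder."""
    return Counter(re.findall(r"\w+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(count * b[word] for word, count in a.items())
    norm_a = math.sqrt(sum(c * c for c in a.values()))
    norm_b = math.sqrt(sum(c * c for c in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def expand_and_score_paths(question: str, paths: List[Path],
                           relations_batched: List[List[str]]) -> List[Path]:
    """Extend each path by each of its candidate relations and score the new hop."""
    expanded = []
    for path, relations in zip(paths, relations_batched):
        # Query = question followed by the relations already on the path.
        query = embed(" ".join([question, *path.prev_relations]))
        for relation in relations:
            similarity = cosine(query, embed(relation))
            expanded.append(Path(path.prev_relations + (relation,), path.score + similarity))
    return expanded


new_paths = expand_and_score_paths(
    "Where was Barack Obama educated?",
    [Path((), 0.0)],
    [["educated at", "country of citizenship"]],
)
```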
- retrieve_subgraph_triplets(sample: Dict[str, Any])
Retrieve triplets as subgraphs from paths.
- Parameters:
sample (dict) – a sample from the dataset, which contains at least the following fields: question (a string) and question_entities (a list of entities)
- Returns:
a list of triplets
- Return type:
list(tuple)
- srtk.retrieve.calculate_hit_and_miss(retrieved_path)
Count the samples whose retrieved triplets do or do not contain an answer entity (the hits and misses from which recall is computed). This requires answer_entities in each sample.
- Parameters:
retrieved_path (str) – path to the retrieved triplets
- Returns:
the number of samples that have at least one answer entity in the retrieved triplets, and the number of samples that have no answer entity in the retrieved triplets
- Return type:
tuple(int, int)
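A sketch of the hit/miss count and the recall derived from it. The retrieved-file layout (JSONL with `answer_entities` and `triplets` as `[head, relation, tail]` lists per sample) and the name `count_hit_and_miss` are hypothetical; only the hit rule (any answer entity among the retrieved entities) comes from the description.

```python
import json
import os
import tempfile
from typing import Tuple


def count_hit_and_miss(retrieved_path: str) -> Tuple[int, int]:
    """Count samples whose retrieved triplets do / do not contain an answer entity.

    Assumes a JSONL file in which every sample has `answer_entities` and
    `triplets` (a list of [head, relation, tail]); this layout is hypothetical.
    """
    hit = miss = 0
    with open(retrieved_path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            sample = json.loads(line)
            retrieved = {node for head, _, tail in sample["triplets"] for node in (head, tail)}
            if retrieved & set(sample["answer_entities"]):
                hit += 1
            else:
                miss += 1
    return hit, miss


samples = [
    {"answer_entities": ["Q49122"], "triplets": [["Q76", "P69", "Q49122"]]},
    {"answer_entities": ["Q4569677"], "triplets": [["Q76", "P27", "Q30"]]},
]
with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False, encoding="utf-8") as tmp:
    tmp.writelines(json.dumps(s) + "\n" for s in samples)
hit_miss = count_hit_and_miss(tmp.name)
os.remove(tmp.name)
recall = hit_miss[0] / sum(hit_miss)  # hits / (hits + misses)
```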
- srtk.retrieve.calculate_subgraph_size(retrieved_path)
Calculate the average number of triplets, entities and relations in retrieved subgraphs.
- Parameters:
retrieved_path (str) – path to the retrieved triplets
- Returns:
average number of triplets, entities and relations in retrieved subgraphs
- Return type:
tuple
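A sketch of the size statistics, computed over already-loaded samples rather than a file path. It assumes each sample stores its subgraph under a hypothetical `triplets` key and counts distinct entities (heads and tails) and distinct relations per sample.

```python
from typing import Dict, Iterable, List, Tuple


def average_subgraph_size(samples: Iterable[Dict[str, List[List[str]]]]) -> Tuple[float, float, float]:
    """Average number of triplets, distinct entities and distinct relations per sample."""
    n = total_triplets = total_entities = total_relations = 0
    for sample in samples:
        triplets = sample["triplets"]
        n += 1
        total_triplets += len(triplets)
        total_entities += len({node for h, _, t in triplets for node in (h, t)})
        total_relations += len({r for _, r, _ in triplets})
    if n == 0:
        return 0.0, 0.0, 0.0
    return total_triplets / n, total_entities / n, total_relations / n


sizes = average_subgraph_size([
    {"triplets": [["Q76", "P69", "Q49122"], ["Q76", "P69", "Q1346110"]]},
    {"triplets": [["Q76", "P27", "Q30"]]},
])
```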
- srtk.retrieve.print_and_save_recall(retrieved_path)
Calculate and print the recall of answer entities in the retrieved triplets. If any answer from the answer entities is in the retrieved entities, the sample counts as a hit.
- srtk.retrieve.retrieve(args)
Retrieve subgraphs from a knowledge graph.
- Parameters:
args (Namespace) – arguments for subgraph retrieval
srtk.train module
The script to train the scorer model.
For example: python train.py --data-file data/train.jsonl --model-name-or-path intfloat/e5-small --save-model-path artifacts/scorer
- class srtk.train.Collator(tokenizer: transformers.PreTrainedTokenizerBase)
Bases: object
Collate a list of examples into a batch.
- tokenizer: transformers.PreTrainedTokenizerBase
- srtk.train.concate_all(example)
Concatenate all columns into one column for input. The resulting ‘input_text’ column is a list of strings.
- srtk.train.prepare_dataloaders(train_data, validation_data, tokenizer, batch_size)
Prepare dataloaders for training and validation.
If no validation dataset is provided, 5 percent of the training data is used as validation data.
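The 5 percent fallback can be sketched in plain Python. The function name, the shuffling, and the fixed seed are assumptions; the actual function also builds the dataloaders, which is omitted here.

```python
import random
from typing import List, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_train_validation(train_data: Sequence[T], validation_data: Optional[Sequence[T]] = None,
                           ratio: float = 0.05, seed: int = 0) -> Tuple[List[T], List[T]]:
    """Use the given validation data, or hold out `ratio` of the training data."""
    if validation_data is not None:
        return list(train_data), list(validation_data)
    shuffled = list(train_data)
    random.Random(seed).shuffle(shuffled)  # seeded so the split is reproducible
    n_validation = max(1, int(len(shuffled) * ratio))
    return shuffled[n_validation:], shuffled[:n_validation]


train, validation = split_train_validation(list(range(100)))
```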
- srtk.train.train(args)
Train the scorer model.
The model compares the similarity between [question; previous relation] and the next relation.
srtk.visualize module
Visualize the graph (represented as a set of triplets) using pyvis. The visualized subgraphs are HTML files.
- srtk.visualize.visualize(args)
Main entry point for subgraph visualization.
- Parameters:
args (Namespace) – arguments for subgraph visualization.
- srtk.visualize.visualize_subgraph(sample, kg: KnowledgeGraphBase)
Visualize the subgraph. It returns an HTML string.