Python API

Subpackages

srtk.preprocess module

This script creates the training data from the grounded questions.

The input should be a JSONL file in which each line represents one grounded question. Each line should follow the format of this example (shown here spread over several lines for readability):

{
    "id": "sample-id",
    "question": "Which universities did Barack Obama graduate from?",
    "question_entities": [
        "Q76"
    ],
    "answer_entities": [
        "Q49122",
        "Q1346110",
        "Q4569677"
    ]
}
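As an illustration of the format above, the following sketch parses JSONL lines and checks that each one carries the fields shown in the example. The helper name parse_grounded_questions and the decision to treat all four fields as required are assumptions made for this sketch, not part of srtk.

```python
import json

# Fields taken from the example above; treating all of them as required
# is an assumption of this sketch.
REQUIRED_FIELDS = ("id", "question", "question_entities", "answer_entities")


def parse_grounded_questions(lines):
    """Parse JSONL lines into dicts and check the expected fields."""
    samples = []
    for line_number, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        sample = json.loads(line)
        missing = [field for field in REQUIRED_FIELDS if field not in sample]
        if missing:
            raise ValueError(f"line {line_number}: missing fields {missing}")
        samples.append(sample)
    return samples


example = {
    "id": "sample-id",
    "question": "Which universities did Barack Obama graduate from?",
    "question_entities": ["Q76"],
    "answer_entities": ["Q49122", "Q1346110", "Q4569677"],
}
# In the file itself, each JSON object occupies exactly one line.
jsonl_text = json.dumps(example) + "\n"
samples = parse_grounded_questions(jsonl_text.splitlines())
```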

srtk.retrieve module

This script retrieves subgraphs from a knowledge graph according to a natural language query (usually a question). This command can also be used to evaluate a trained retriever when the answer entities are known.

The expected fields of one sample are:
  • question: question text
  • question_entities: list of grounded question entities (ids)

For evaluation, the following field is also required:
  • answer_entities: list of grounded answer entities (ids)

class srtk.retrieve.KnowledgeGraphTraverser(kg: KnowledgeGraphBase)

KnowledgeGraphTraverser is a helper class that traverses a knowledge graph.

deduce_leaf_relations(entity, path)

Deduce leaf relations from an entity following a path hop by hop.

Parameters:
  • entity (str) – the identifier of the source node

  • path (list[str]) – a list of relation identifiers

Returns:

a tuple of relations that are n-hop away from the source node, where n is the length of the path

Return type:

tuple[str]

deduce_leaves(entity, path)

Deduce leaves from an entity following a path hop by hop.

Parameters:
  • entity (str) – the identifier of the source node

  • path (list[str]) – a list of relation identifiers

Returns:

a set of leaf identifiers that are n-hop away from the source node, where n is the length of the path

Return type:

set[str]
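The hop-by-hop behaviour of deduce_leaves and deduce_leaf_relations can be sketched on a toy in-memory graph. Everything in this block is hypothetical (the edge list, the identifiers and the toy_ functions); the real class queries a KnowledgeGraphBase backend instead, and reading "leaf relations" as the relations leaving the leaves is this sketch's interpretation.

```python
# Hypothetical toy graph as (head, relation, tail) triplets.
TOY_EDGES = [
    ("person", "educated_at", "univ_a"),
    ("person", "educated_at", "univ_b"),
    ("univ_a", "located_in", "city_x"),
    ("univ_b", "located_in", "city_y"),
    ("univ_b", "founded_by", "founder_z"),
]


def toy_deduce_leaves(entity, path):
    """Follow the relations in `path` one hop at a time from `entity`."""
    frontier = {entity}
    for relation in path:
        frontier = {
            tail
            for head, rel, tail in TOY_EDGES
            if head in frontier and rel == relation
        }
    return frontier


def toy_deduce_leaf_relations(entity, path):
    """Collect the relations that leave the nodes reached after `path`."""
    leaves = toy_deduce_leaves(entity, path)
    relations = {rel for head, rel, _ in TOY_EDGES if head in leaves}
    return tuple(sorted(relations))


leaves = toy_deduce_leaves("person", ["educated_at", "located_in"])
leaf_relations = toy_deduce_leaf_relations("person", ["educated_at"])
```

With an empty path, the sketch returns the source node itself as the only leaf, so the leaf relations are simply the relations leaving the source.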

get_relation_label(identifier)

Get the label of an entity or a relation.

It serves as a proxy to the knowledge graph’s get_label function. For Freebase, the identifier is used directly as the label. For other knowledge graphs, the retrieved label is returned if it exists; otherwise the identifier is returned.

Parameters:

identifier (str) – the identifier of an entity or a relation

Returns:

the label of the entity or the relation

Return type:

str
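The fallback rule described above can be sketched as follows. The function toy_get_label, its identifier_is_label flag and the labels dictionary are stand-ins invented for this sketch; the real method delegates to the knowledge graph's get_label function.

```python
def toy_get_label(identifier, labels, identifier_is_label=False):
    """Return a label for `identifier`, falling back to the identifier.

    `labels` is a hypothetical identifier -> label mapping standing in
    for a knowledge graph lookup. `identifier_is_label` mimics the
    Freebase case, where the identifier is used as the label directly.
    """
    if identifier_is_label:
        return identifier
    label = labels.get(identifier)
    return label if label else identifier


labels = {"P69": "educated at"}  # illustrative mapping only
known = toy_get_label("P69", labels)
unknown = toy_get_label("P999999", labels)
freebase_style = toy_get_label(
    "people.person.education", labels, identifier_is_label=True
)
```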

retrive_subgraph(entity, path)

Retrieve subgraph entities and triplets by traversing from an entity following a relation path hop by hop.

Parameters:
  • entity (str) – the identifier of the source node

  • path (list[str]) – a list of relation identifiers

Returns:

entities, a list of entity identifiers, and triplets, a list of triplets

Return type:

tuple
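A minimal sketch of this kind of traversal, collecting both the visited entities and the traversed triplets, is given below. The toy edge list and toy_retrieve_subgraph are hypothetical; the order of the results and the handling of duplicates are assumptions of the sketch rather than documented behaviour.

```python
# Hypothetical toy graph as (head, relation, tail) triplets.
TOY_EDGES = [
    ("person", "educated_at", "univ_a"),
    ("person", "educated_at", "univ_b"),
    ("univ_a", "located_in", "city_x"),
    ("univ_b", "located_in", "city_y"),
    ("univ_b", "founded_by", "founder_z"),
]


def toy_retrieve_subgraph(entity, path):
    """Walk `path` from `entity`, recording entities and triplets."""
    entities = [entity]
    triplets = []
    frontier = [entity]
    for relation in path:
        next_frontier = []
        for node in frontier:
            for head, rel, tail in TOY_EDGES:
                if head == node and rel == relation:
                    triplets.append((head, rel, tail))
                    if tail not in entities:
                        entities.append(tail)
                    if tail not in next_frontier:
                        next_frontier.append(tail)
        frontier = next_frontier
    return entities, triplets


entities, triplets = toy_retrieve_subgraph(
    "person", ["educated_at", "located_in"]
)
```

Only edges whose relation matches the current hop are kept, so the founded_by edge does not appear in the result for this path.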

class srtk.retrieve.Path(prev_relations, score)
prev_relations

Alias for field number 0

score

Alias for field number 1
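The "Alias for field number" entries are the descriptions Python generates for named tuple fields, so Path behaves like a two-field named tuple. The stand-in below is defined locally for illustration; storing prev_relations as a tuple of relation identifiers and adding scores when extending a path are assumptions.

```python
from collections import namedtuple

# Local stand-in with the same two fields in the documented order.
Path = namedtuple("Path", ["prev_relations", "score"])

start = Path(prev_relations=(), score=0.0)
# Named tuples are immutable, so extending a path creates a new one.
extended = Path(start.prev_relations + ("educated_at",), start.score + 0.9)

# Fields are reachable by name, by index, or by unpacking.
relations, score = extended
```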

class srtk.retrieve.Retriever(kg: KnowledgeGraphBase, scorer: Scorer, beam_width: int, max_depth: int)

Retriever retrieves subgraphs from a knowledge graph with a question and its linked entities. The retrieval process takes the semantic information of the question and the expanding path into consideration.

beam_search_path(question: str, question_entity: str)

This function reimplements the solution from RUC’s paper. During the search, only the history paths are recorded; each new relation is found by looking up the end relations reached from the question entity by following a history path.

Parameters:
  • question (str) – a natural language question

  • question_entity (str) – a grounded question entity

Returns:

path score list, a list of (path, score) tuples

Return type:

list[Path]
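The idea of keeping only history paths and re-deriving candidate relations by walking each path from the question entity can be sketched as below. This is a simplified illustration, not the library's implementation: the toy graph, the word-overlap scorer and the additive path score are all assumptions, and details such as how a path is terminated are left out.

```python
from collections import namedtuple

Path = namedtuple("Path", ["prev_relations", "score"])

# Hypothetical toy graph as (head, relation, tail) triplets.
TOY_EDGES = [
    ("person", "educated_at", "univ_a"),
    ("person", "born_in", "city_x"),
    ("person", "spouse", "partner"),
    ("univ_a", "located_in", "city_y"),
]


def end_relations(entity, prev_relations):
    """Relations leaving the nodes reached by following prev_relations."""
    frontier = {entity}
    for relation in prev_relations:
        frontier = {t for h, r, t in TOY_EDGES if h in frontier and r == relation}
    return sorted({r for h, r, _ in TOY_EDGES if h in frontier})


def toy_score(question, next_relation):
    """Hypothetical scorer: share of the relation's words in the question."""
    question_words = set(question.lower().strip("?").split())
    words = next_relation.split("_")
    return sum(word in question_words for word in words) / len(words)


def toy_beam_search(question, question_entity, beam_width=2, max_depth=2):
    beam = [Path((), 0.0)]
    kept = []
    for _ in range(max_depth):
        candidates = []
        for path in beam:
            # Only the path is stored; candidates are looked up again.
            for relation in end_relations(question_entity, path.prev_relations):
                candidates.append(
                    Path(
                        path.prev_relations + (relation,),
                        path.score + toy_score(question, relation),
                    )
                )
        if not candidates:
            break
        candidates.sort(key=lambda p: p.score, reverse=True)
        beam = candidates[:beam_width]
        kept.extend(beam)
    return kept


paths = toy_beam_search("Where was the person educated?", "person")
```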

expand_and_score_paths(question: str, paths: List[Path], relations_batched: List[List[str]]) → List[Path]

Expand the paths by one hop and score them by comparing the embedding similarity between the query (question + prev_relations) and the next relation.

Parameters:
  • question (str) – a natural language question

  • paths (list[Path]) – a list of current paths

  • relations_batched (list[list[str]]) – a list of next relations for each path

Returns:

scored_paths, a list of newly expanded and scored paths

Return type:

list[Path]
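The expand-and-score step can be sketched with a toy embedding in place of a trained encoder. In this block the bag-of-words toy_embed function, the way the query text is joined, and the choice to add the similarity to the running path score are assumptions made for illustration only.

```python
import math
from collections import Counter, namedtuple

Path = namedtuple("Path", ["prev_relations", "score"])


def toy_embed(text):
    """Hypothetical embedding: word counts instead of a trained encoder."""
    return Counter(text.lower().replace("_", " ").split())


def cosine(a, b):
    dot = sum(a[key] * b[key] for key in a)
    norm_a = math.sqrt(sum(value * value for value in a.values()))
    norm_b = math.sqrt(sum(value * value for value in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def toy_expand_and_score(question, paths, relations_batched):
    """Extend each path by each of its candidate relations and score it."""
    scored = []
    for path, relations in zip(paths, relations_batched):
        # Query = question text followed by the relations already taken.
        query = toy_embed(" ".join((question, *path.prev_relations)))
        for relation in relations:
            similarity = cosine(query, toy_embed(relation))
            scored.append(
                Path(tuple(path.prev_relations) + (relation,), path.score + similarity)
            )
    return scored


scored_paths = toy_expand_and_score(
    "where was obama educated",
    [Path((), 0.0)],
    [["educated_at", "spouse"]],
)
```

Note that relations_batched is aligned with paths: the i-th inner list holds the candidate next relations for the i-th path.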

retrieve_subgraph_triplets(sample: Dict[str, Any])

Retrieve triplets as subgraphs from paths.

Parameters:

sample (dict) – a sample from the dataset, which contains at least the following fields:
  • question: a string
  • question_entities: a list of entities

Returns:

a list of triplets

Return type:

list(tuple)

srtk.retrieve.calculate_hit_and_miss(retrieved_path)

Calculate the recall of answer entities in retrieved triplets, if answer_entities exists in each sample.

Parameters:

retrieved_path (str) – path to the retrieved triplets

Returns:

number of samples that have at least one answer entity in retrieved triplets, and number of samples that have no answer entity in retrieved triplets

Return type:

tuple(int, int)
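A minimal sketch of the hit and miss count is shown below, working on JSONL lines held in memory rather than on a file path. The "triplets" field name, the placeholder identifiers and the choice to skip samples without answer_entities are assumptions of this sketch.

```python
import json


def toy_hit_and_miss(lines):
    """Count samples whose retrieved triplets contain an answer entity."""
    hit = miss = 0
    for line in lines:
        sample = json.loads(line)
        answers = set(sample.get("answer_entities", []))
        if not answers:
            continue  # assumption: samples without answers are not counted
        retrieved = set()
        for head, _relation, tail in sample["triplets"]:
            retrieved.update((head, tail))
        if answers & retrieved:
            hit += 1
        else:
            miss += 1
    return hit, miss


# Placeholder identifiers (E1, R1, ...) used only for illustration.
lines = [
    json.dumps({"triplets": [["E1", "R1", "E2"]], "answer_entities": ["E2"]}),
    json.dumps({"triplets": [["E1", "R2", "E3"]], "answer_entities": ["E2"]}),
]
hit, miss = toy_hit_and_miss(lines)
recall = hit / (hit + miss)
```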

srtk.retrieve.calculate_subgraph_size(retrieved_path)

Calculate the average number of triplets, entities and relations in retrieved subgraphs.

Parameters:

retrieved_path (str) – path to the retrieved triplets

Returns:

average number of triplets, entities and relations in retrieved subgraphs

Return type:

tuple
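The averages can be sketched as follows for subgraphs already loaded into memory. The function toy_subgraph_size is hypothetical, and counting every listed triplet (without deduplication) is an assumption.

```python
def toy_subgraph_size(subgraphs):
    """Average number of triplets, entities and relations per subgraph."""
    if not subgraphs:
        return 0.0, 0.0, 0.0
    total_triplets = total_entities = total_relations = 0
    for triplets in subgraphs:
        entities = set()
        relations = set()
        for head, relation, tail in triplets:
            entities.update((head, tail))
            relations.add(relation)
        total_triplets += len(triplets)
        total_entities += len(entities)
        total_relations += len(relations)
    count = len(subgraphs)
    return total_triplets / count, total_entities / count, total_relations / count


# Placeholder identifiers used only for illustration.
subgraphs = [
    [("a", "r1", "b"), ("b", "r2", "c")],
    [("a", "r1", "b")],
]
avg_triplets, avg_entities, avg_relations = toy_subgraph_size(subgraphs)
```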

srtk.retrieve.print_and_save_recall(retrieved_path)

Calculate and print the recall of answer entities in retrieved triplets. If any answer from the answer entities is in the retrieved entities, the sample counts as a hit.

srtk.retrieve.retrieve(args)

Retrieve subgraphs from a knowledge graph.

Parameters:

args (Namespace) – arguments for subgraph retrieval

srtk.train module

The script to train the scorer model.

e.g. python train.py --data-file data/train.jsonl --model-name-or-path intfloat/e5-small --save-model-path artifacts/scorer

class srtk.train.Collator(tokenizer: transformers.PreTrainedTokenizerBase)

Bases: object

Collate a list of examples into a batch.

tokenizer: transformers.PreTrainedTokenizerBase

srtk.train.concate_all(example)

Concatenate all columns into one column for input. The resulting ‘input_text’ column is a list of strings.

srtk.train.prepare_dataloaders(train_data, validation_data, tokenizer, batch_size)

Prepare dataloaders for training and validation.

If no validation dataset is provided, 5 percent of the training data will be used as validation data.
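The 5 percent fallback can be illustrated with plain Python lists. The helper toy_split, the shuffling with a fixed seed, the integer rounding and the minimum of one validation example are assumptions; the library may split its dataset differently.

```python
import random


def toy_split(examples, validation_percent=5, seed=0):
    """Hold out a percentage of the examples for validation."""
    indices = list(range(len(examples)))
    random.Random(seed).shuffle(indices)
    # Integer arithmetic avoids floating-point rounding surprises.
    n_validation = max(1, len(examples) * validation_percent // 100)
    held_out = set(indices[:n_validation])
    train = [ex for i, ex in enumerate(examples) if i not in held_out]
    validation = [ex for i, ex in enumerate(examples) if i in held_out]
    return train, validation


examples = [f"example-{i}" for i in range(100)]
train_data, validation_data = toy_split(examples)
```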

srtk.train.train(args)

Train the scorer model.

The model compares the similarity between [question; previous relation] and the next relation.

srtk.visualize module

Visualize the graph (represented as a set of triplets) using pyvis. The visualized subgraphs are saved as HTML files.

srtk.visualize.visualize(args)

Main entry for subgraph visualization.

Parameters:

args (Namespace) – arguments for subgraph visualization.

srtk.visualize.visualize_subgraph(sample, kg: KnowledgeGraphBase)

Visualize the subgraph. It returns an HTML string.