Skip to content

Cfg

neps.search_spaces.architecture.cfg #

Grammar #

Grammar(*args, **kwargs)

Bases: CFG

Extended context-free grammar (CFG) class from the NLTK Python package. We have provided functionality to sample from the CFG. We have included generation capability within the class (before it was an external function). It also allows sampling to return whole trees (not just the string of terminals).

Source code in neps/search_spaces/architecture/cfg.py
def __init__(self, *args, **kwargs):
    """Build the grammar and cache symbol tables derived from its productions.

    All positional and keyword arguments are forwarded unchanged to the
    parent class constructor.

    Attributes set here:
        nonterminals: unique left-hand-side symbols of the productions, as
            strings (order is not meaningful, it comes from a set).
        terminals: right-hand-side symbols, as strings, that are not in
            ``nonterminals`` (order is not meaningful).
        swappable_nonterminals: nonterminals that are the left-hand side of
            more than one production.
        _prior: initialised to ``None``.

    Raises:
        Exception: if a symbol is both a terminal and a nonterminal, or if a
            nonterminal has no production.
    """
    super().__init__(*args, **kwargs)

    # Left-hand sides of all productions, duplicates included: the number of
    # occurrences tells how many productions each nonterminal has.
    lhs_symbols = [str(prod.lhs()) for prod in self.productions()]
    self.nonterminals = list(set(lhs_symbols))
    self.terminals = list(
        {str(symbol) for prod in self.productions() for symbol in prod.rhs()}
        - set(self.nonterminals)
    )

    # Count productions per nonterminal in one pass instead of calling
    # ``list.count`` for every element (which is quadratic).
    production_counts: dict[str, int] = {}
    for symbol in lhs_symbols:
        production_counts[symbol] = production_counts.get(symbol, 0) + 1

    # Nonterminals worth swapping in genetic operations: only those with more
    # than one production, since a single production offers no alternative.
    self.swappable_nonterminals = list(
        {symbol for symbol, count in production_counts.items() if count > 1}
    )

    self._prior = None

    # NOTE(review): ``terminals`` is built above by subtracting the
    # nonterminals, so this overlap is empty by construction; the check is
    # kept as a safeguard.
    overlap = set(self.terminals).intersection(set(self.nonterminals))
    if len(overlap) > 0:
        raise Exception(f"Same terminal and nonterminal symbol: {overlap}!")

    # Every nonterminal must be expandable by at least one production.
    for nt in self.nonterminals:
        if len(self.productions(Nonterminal(nt))) == 0:
            raise Exception(f"There is no production for nonterminal {nt}")

mutate #

mutate(
    parent: str,
    subtree_index: int,
    subtree_node: str,
    patience: int = 50,
) -> str

Grammar-based mutation, i.e., we sample a new subtree from a nonterminal node in the parse tree.

PARAMETER DESCRIPTION
parent

parent of the mutation.

TYPE: str

subtree_index

index pointing to the node that is root of the subtree.

TYPE: int

subtree_node

nonterminal symbol of the node.

TYPE: str

patience

Number of tries. Defaults to 50.

TYPE: int DEFAULT: 50

RETURNS DESCRIPTION
str

mutated child from parent.

TYPE: str

Source code in neps/search_spaces/architecture/cfg.py
def mutate(
    self, parent: str, subtree_index: int, subtree_node: str, patience: int = 50
) -> str:
    """Grammar-based mutation, i.e., we sample a new subtree from a nonterminal
    node in the parse tree.

    The subtree rooted at ``subtree_index`` is cut out of ``parent`` and
    replaced by a freshly sampled subtree that starts at ``subtree_node``.
    Sampling is retried until the result differs from ``parent`` or the
    number of tries is used up.

    Args:
        parent (str): parent of the mutation.
        subtree_index (int): index pointing to the node that is root of the subtree.
        subtree_node (str): nonterminal symbol of the node.
        patience (int, optional): Number of tries. Defaults to 50.

    Returns:
        str: mutated child from parent, stripped of surrounding whitespace.
            If no differing child was found within ``patience`` tries (or
            ``patience`` is not positive), the result equals the stripped
            parent.
    """
    # Chop out the subtree; the removed part itself is not needed.
    pre, _, post = self.remove_subtree(parent, subtree_index)

    # Start from the parent so the return value is defined even when the loop
    # body never runs (``patience <= 0``).
    child = parent
    tries_left = patience
    while tries_left > 0:
        # Only sample the subtree -> avoids full sampling of large parse trees.
        new_subtree = self.sampler(1, start_symbol=subtree_node)[0]
        child = pre + new_subtree + post
        if parent != child:  # ensure that parent is really mutated
            break
        tries_left -= 1

    return child.strip()

rand_subtree #

rand_subtree(tree: str) -> tuple[str, int]

Helper function to choose a random subtree in a given parse tree. Runs a single pass through the tree (stored as string) to look for the location of swappable nonterminal symbols.

PARAMETER DESCRIPTION
tree

parse tree.

TYPE: str

RETURNS DESCRIPTION
tuple[str, int]

Tuple[str, int]: return the parent node of the subtree and its index.

Source code in neps/search_spaces/architecture/cfg.py
def rand_subtree(self, tree: str) -> tuple[str, int]:
    """Pick a random subtree root among the swappable nonterminals of a tree.

    The parse tree string is split on single spaces and scanned once; a token
    is a candidate when the text after its first character (the opening
    bracket) is one of ``self.swappable_nonterminals``.

    The random position is drawn with ``np.random.randint(1, k)`` where ``k``
    is the number of candidates, so the first candidate is never selected and
    fewer than two candidates make the draw raise.

    Args:
        tree (str): parse tree.

    Returns:
        Tuple[str, int]: the nonterminal symbol at the root of the chosen
            subtree and the index of its token in the space-split tree.
    """
    tokens = tree.split(" ")
    # Positions of tokens such as "(S" whose symbol is swappable.
    candidate_positions = [
        position
        for position, token in enumerate(tokens)
        if token[1:] in self.swappable_nonterminals
    ]
    pick = np.random.randint(1, len(candidate_positions))
    token_index = candidate_positions[pick]
    symbol = tokens[token_index][1:]
    return symbol, token_index

remove_subtree staticmethod #

remove_subtree(
    tree: str, index: int
) -> tuple[str, str, str]

Helper function to remove a subtree from a parse tree given its index. E.g. '(S (S (T 2)) (ADD +) (T 1))' becomes '(S (S (T 2)) ', ' (T 1))' after removing (ADD +).

PARAMETER DESCRIPTION
tree

parse tree

TYPE: str

index

index of the subtree root node

TYPE: int

RETURNS DESCRIPTION
tuple[str, str, str]

Tuple[str, str, str]: part before the subtree, subtree, part past subtree

Source code in neps/search_spaces/architecture/cfg.py
@staticmethod
def remove_subtree(tree: str, index: int) -> tuple[str, str, str]:
    """Cut the subtree rooted at a given token out of a parse tree string.

    The tree is split on single spaces and ``index`` refers to the token that
    opens the subtree (e.g. ``"(ADD"``).
    E.g. '(S (S (T 2)) (ADD +) (T 1))' with index 4
    yields '(S (S (T 2)) ', '(ADD +)', ' (T 1))'.

    Args:
        tree (str): parse tree
        index (int): index of the subtree root node

    Returns:
        Tuple[str, str, str]: part before the subtree (always ending in a
            space), the subtree itself, part past the subtree
    """
    tokens = tree.split(" ")
    root_token = tokens[index]
    before = " ".join(tokens[:index]) + " "
    # Everything after the root token; the subtree ends somewhere in here.
    remainder = " ".join(tokens[index + 1 :])

    # The root token contributed one open bracket; walk forward until it is
    # balanced. If it never is, the whole remainder belongs to the subtree.
    depth = 1
    closing_at = len(remainder)
    for position, symbol in enumerate(remainder):
        if symbol == "(":
            depth += 1
        elif symbol == ")":
            depth -= 1
        if depth == 0:
            closing_at = position
            break

    subtree = root_token + " " + remainder[: closing_at + 1]
    after = remainder[closing_at + 1 :]
    return (before, subtree, after)