Skip to content

Constrained cfg

neps.search_spaces.architecture.cfg_variants.constrained_cfg #

ConstrainedGrammar #

ConstrainedGrammar(*args, **kwargs)

Bases: Grammar

Source code in neps/search_spaces/architecture/cfg_variants/constrained_cfg.py
def __init__(self, *args, **kwargs):
    """Build the grammar and reset all constraint-related state.

    Every positional and keyword argument is forwarded unchanged to the
    base ``Grammar`` initializer; afterwards the constraint bookkeeping
    attributes are (re)set to their empty defaults.
    """
    super().__init__(*args, **kwargs)

    # Constraint bookkeeping: nothing is configured yet.
    self.constraint_is_class: bool = False
    self.constraints = None
    self.none_operation = None

    # Holds None until a prior is assigned (annotated as dict).
    self._prior: dict = None

compute_space_size property #

compute_space_size: int

Computes the size of the space described by the grammar.

PARAMETER DESCRIPTION
primitive_nonterminal

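The primitive nonterminal of the grammar. Defaults to "OPS". Note: compute_space_size is documented as a property, so callers cannot pass this argument; presumably the default is always used — verify against the source.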

TYPE: str

RETURNS DESCRIPTION
int

size of space described by grammar.

TYPE: int

rand_subtree #

rand_subtree(tree: str) -> Tuple[str, int]

Helper function to choose a random subtree in a given parse tree. Runs a single pass through the tree (stored as string) to look for the location of swappable nonterminal symbols.

PARAMETER DESCRIPTION
tree

parse tree.

TYPE: str

RETURNS DESCRIPTION
Tuple[str, int]

Tuple[str, int]: the nonterminal symbol at the root of the chosen subtree (without its leading "(") and the index of that token in the space-split tree.

Source code in neps/search_spaces/architecture/cfg.py
def rand_subtree(self, tree: str) -> Tuple[str, int]:
    """Helper function to choose a random subtree in a given parse tree.

    Runs a single pass through the tree (stored as string) to look for
    the location of swappable nonterminal symbols. A token marks such a
    symbol when it starts with "(" and the remainder of the token is in
    ``self.swappable_nonterminals``. The token at position 0 (the root of
    the whole tree) is never chosen; every other candidate is equally likely.

    Args:
        tree (str): parse tree.

    Returns:
        Tuple[str, int]: the nonterminal symbol at the root of the chosen
        subtree (without its leading "(") and the index of that token in
        ``tree.split(" ")``.

    Raises:
        ValueError: if no swappable nonterminal exists below the root.
    """
    split_tree = tree.split(" ")
    swappable = self.swappable_nonterminals
    # Exclude only the root token (position 0) so a proper subtree is picked.
    # Skipping the first *swappable* occurrence instead would wrongly drop a
    # valid candidate whenever the root itself is not swappable.
    swappable_indices = [
        i
        for i, token in enumerate(split_tree)
        if i > 0 and token.startswith("(") and token[1:] in swappable
    ]
    if not swappable_indices:
        raise ValueError(
            f"No swappable nonterminal found below the root of tree: {tree!r}"
        )
    r = np.random.randint(0, len(swappable_indices))
    chosen_non_terminal_index = swappable_indices[r]
    chosen_non_terminal = split_tree[chosen_non_terminal_index][1:]
    return chosen_non_terminal, chosen_non_terminal_index

remove_subtree staticmethod #

remove_subtree(
    tree: str, index: int
) -> Tuple[str, str, str]

Helper function to remove a subtree from a parse tree given the index of its root token. E.g. removing the subtree rooted at (ADD from '(S (S (T 2)) (ADD +) (T 1))' yields '(S (S (T 2)) ', '(ADD +)' and ' (T 1))' (the trailing part keeps its leading space).

PARAMETER DESCRIPTION
tree

parse tree

TYPE: str

index

index of the subtree root node

TYPE: int

RETURNS DESCRIPTION
Tuple[str, str, str]

Tuple[str, str, str]: part before the subtree, subtree, part past subtree

Source code in neps/search_spaces/architecture/cfg.py
@staticmethod
def remove_subtree(tree: str, index: int) -> Tuple[str, str, str]:
    """Cut the subtree rooted at token ``index`` out of a parse tree.

    The tree is split on single spaces; the token at ``index`` is assumed
    to open the subtree (e.g. ``"(ADD"``). Its matching closing bracket is
    found with a single scan that tracks bracket depth.

    Example:
        remove_subtree("(S (S (T 2)) (ADD +) (T 1))", 4)
        returns ("(S (S (T 2)) ", "(ADD +)", " (T 1))").

    Args:
        tree (str): parse tree.
        index (int): index, in ``tree.split(" ")``, of the subtree root token.

    Returns:
        Tuple[str, str, str]: the text before the subtree (ending in a
        space), the subtree itself, and the text after it. If no matching
        bracket is found, the subtree runs to the end of the tree and the
        last part is empty.
    """
    tokens = tree.split(" ")
    root_token = tokens[index]
    before = " ".join(tokens[:index]) + " "
    tail = " ".join(tokens[index + 1 :])

    # The root token has already opened one bracket; scan until it closes.
    depth = 1
    close_at = len(tail)
    for position, char in enumerate(tail):
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
        if depth == 0:
            close_at = position
            break

    subtree = f"{root_token} {tail[: close_at + 1]}"
    after = tail[close_at + 1 :]
    return before, subtree, after