Skip to content

Core graph grammar

neps.search_spaces.architecture.core_graph_grammar #

CoreGraphGrammar #

CoreGraphGrammar(
    grammars: list[Grammar] | Grammar,
    terminal_to_op_names: dict,
    terminal_to_graph_edges: dict = None,
    edge_attr: bool = True,
    edge_label: str = "op_name",
    zero_op: list = None,
    identity_op: list = None,
    name: str = None,
    scope: str = None,
    return_all_subgraphs: bool = False,
    return_graph_per_hierarchy: bool = False,
)

Bases: Graph

Source code in neps/search_spaces/architecture/core_graph_grammar.py
def __init__(
    self,
    grammars: list[Grammar] | Grammar,
    terminal_to_op_names: dict,
    terminal_to_graph_edges: dict = None,
    edge_attr: bool = True,
    edge_label: str = "op_name",
    zero_op: list = None,
    identity_op: list = None,
    name: str = None,
    scope: str = None,
    return_all_subgraphs: bool = False,
    return_graph_per_hierarchy: bool = False,
):
    """Create a graph backed by one or more grammars.

    Args:
        grammars: A single Grammar or a list of them; a single one is wrapped
            into a one-element list.
        terminal_to_op_names: Mapping from terminal symbols to op names. Every
            terminal of every grammar must appear as a key.
        terminal_to_graph_edges: Precomputed edge lists per topology. If None,
            they are derived from ``terminal_to_op_names`` via
            ``get_edge_lists_of_topologies``.
        edge_attr: Stored as ``self.edge_attr``.
        edge_label: Stored as ``self.edge_label``.
        zero_op: Stored as ``self.zero_op`` (empty list if None).
        identity_op: Stored as ``self.identity_op`` (empty list if None).
        name: Forwarded to the base class constructor.
        scope: Forwarded to the base class constructor.
        return_all_subgraphs: Stored as ``self.return_all_subgraphs``.
        return_graph_per_hierarchy: Stored as ``self.return_graph_per_hierarchy``.

    Raises:
        Exception: If some grammar terminal has no entry in
            ``terminal_to_op_names``.
    """
    super().__init__(name, scope)

    # Accept a lone grammar for convenience; internally always keep a list.
    if isinstance(grammars, Grammar):
        grammars = [grammars]
    self.grammars = grammars

    self.terminal_to_op_names = terminal_to_op_names

    # Every terminal used by any grammar must be mapped to an op.
    used_terminals = {
        symbol for grammar in self.grammars for symbol in grammar.terminals
    }
    unmapped = used_terminals - set(self.terminal_to_op_names.keys())
    if unmapped:
        raise Exception(f"Terminals {unmapped} not defined in primitive mapping!")

    # Deriving the edge lists is costly, so reuse them when the caller has them.
    if terminal_to_graph_edges is not None:
        self.terminal_to_graph_edges = terminal_to_graph_edges
    else:
        self.terminal_to_graph_edges = get_edge_lists_of_topologies(
            self.terminal_to_op_names
        )
    self.edge_attr = edge_attr
    self.edge_label = edge_label

    self.zero_op = [] if zero_op is None else zero_op
    self.identity_op = [] if identity_op is None else identity_op

    self.terminal_to_graph_nodes: dict = {}

    self.return_all_subgraphs = return_all_subgraphs
    self.return_graph_per_hierarchy = return_graph_per_hierarchy

OPTIMIZER_SCOPE class-attribute instance-attribute #

OPTIMIZER_SCOPE = 'all'

Whether the search space has an interface to one of the tabular benchmarks which can then be used to query architecture performances.

If this is set to true then query() should be implemented.

__hash__ #

__hash__()

Because comparing graphs is very complicated (e.g. checking all edge attributes, whether they share attributes, ...), the hash is built only from the name, the scope (if set) and the instance id.

This is used when determining whether two instances are copies.

Source code in neps/search_spaces/architecture/graph.py
def __hash__(self):
    """
    As it is very complicated to compare graphs (i.e. check all edge
    attributes, do the have shared attributes, ...) use just the name
    for comparison.

    This is used when determining whether two instances are copies.
    """
    h = 0
    h += hash(self.name)
    h += hash(self.scope) if self.scope else 0
    h += hash(self._id)
    return h

add_edges_densly #

add_edges_densly()

Adds edges to get a fully connected DAG without cycles

Source code in neps/search_spaces/architecture/graph.py
def add_edges_densly(self):
    """
    Densely connect the graph into a fully connected DAG without cycles.

    The edges come from ``self.get_dense_edges()``; this method only inserts
    them via ``add_edges_from``.
    """
    dense_edges = self.get_dense_edges()
    self.add_edges_from(dense_edges)

add_node #

add_node(node_index, **attr)

Adds a node to the graph.

Note that adding a node using an index that has been used already will override its attributes.

PARAMETER DESCRIPTION
node_index

The index for the node. Expected to be >= 1.

TYPE: int

**attr

The attributes which can be added in a dict like form.

DEFAULT: {}

Source code in neps/search_spaces/architecture/graph.py
def add_node(self, node_index, **attr):
    """
    Insert a node, or overwrite the attributes of an existing one.

    Re-using an index that is already present replaces that node's
    attributes, following ``nx.DiGraph.add_node`` semantics.

    Args:
        node_index (int): Index of the node; must be >= 1.
        **attr: Node attributes given as keyword arguments.

    Raises:
        AssertionError: If ``node_index`` is smaller than 1.
    """
    assert node_index >= 1, "Expecting the node index to be greater or equal 1"
    base_add_node = nx.DiGraph.add_node
    base_add_node(self, node_index, **attr)

assemble_trees #

assemble_trees(
    base_tree: str | DiGraph,
    motif_trees: list[str] | list[DiGraph],
    terminal_to_sublanguage_map: dict = None,
    node_label: str = "op_name",
) -> str | DiGraph

Assembles the base parse tree with the motif parse trees

PARAMETER DESCRIPTION
base_tree

Base parse tree

TYPE: DiGraph

motif_trees

List of motif parse trees

TYPE: List[DiGraph]

node_label

node label key. Defaults to "op_name".

TYPE: str DEFAULT: 'op_name'

RETURNS DESCRIPTION
str | DiGraph

str | nx.DiGraph: Assembled parse tree (only string trees are currently supported; nx.DiGraph inputs raise NotImplementedError).

Source code in neps/search_spaces/architecture/core_graph_grammar.py
def assemble_trees(
    self,
    base_tree: str | nx.DiGraph,
    motif_trees: list[str] | list[nx.DiGraph],
    terminal_to_sublanguage_map: dict = None,
    node_label: str = "op_name",
) -> str | nx.DiGraph:
    """Assembles the base parse tree with the motif parse trees.

    Only string parse trees are currently supported. The keys of
    ``terminal_to_sublanguage_map`` are paired positionally with
    ``motif_trees``. Each key is then replaced, in that order, by its motif
    tree everywhere it occurs in ``base_tree``.

    Args:
        base_tree (str | nx.DiGraph): Base parse tree.
        motif_trees (list[str] | list[nx.DiGraph]): Motif parse trees, one per
            key of ``terminal_to_sublanguage_map``. Surplus entries on either
            side are ignored.
        terminal_to_sublanguage_map (dict, optional): Its keys (in iteration
            order) are the symbols to replace in ``base_tree``. Required for
            string trees.
        node_label (str, optional): node label key. Defaults to "op_name".
            Currently unused.

    Returns:
        str: Assembled parse tree.

    Raises:
        ValueError: If the motif trees are not all of the base tree's type.
        NotImplementedError: If ``terminal_to_sublanguage_map`` is missing for a
            string tree, or if ``base_tree`` is not a string (graph-based
            assembly is not implemented).
    """
    if not all(isinstance(base_tree, type(tree)) for tree in motif_trees):
        raise ValueError("All trees must be of the same type!")

    if isinstance(base_tree, str):
        if terminal_to_sublanguage_map is None:
            raise NotImplementedError(
                "Assembling string trees requires a terminal_to_sublanguage_map!"
            )
        assembled_tree = base_tree
        # Replacements are applied sequentially, so a later key can also match
        # text that an earlier replacement inserted.
        for motif, replacement in zip(terminal_to_sublanguage_map, motif_trees):
            assembled_tree = assembled_tree.replace(motif, replacement)
        return assembled_tree

    # nx.DiGraph trees (and any other type) are not supported. The previous
    # graph-based implementation sat behind an unconditional raise and was
    # unreachable, so it has been removed.
    raise NotImplementedError(
        f"Assembling of trees of type {type(base_tree)} is not supported!"
    )

build_graph_from_tree #

build_graph_from_tree(
    tree: DiGraph,
    terminal_to_torch_map: dict,
    node_label: str = "op_name",
    flatten_graph: bool = True,
    return_cell: bool = False,
) -> None | Graph

Builds the computational graph from a parse tree.

PARAMETER DESCRIPTION
tree

parse tree.

TYPE: DiGraph

terminal_to_torch_map

Mapping from terminal symbols to primitives or topologies.

TYPE: dict

node_label

Key to access terminal symbol. Defaults to "op_name".

TYPE: str DEFAULT: 'op_name'

return_cell

Whether to return a cell. Is only needed if cell is repeated multiple times.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
None | Graph

None | Graph: None when the graph is built into self; the built cell when return_cell is True.

Source code in neps/search_spaces/architecture/core_graph_grammar.py
def build_graph_from_tree(
    self,
    tree: nx.DiGraph,
    terminal_to_torch_map: dict,
    node_label: str = "op_name",
    flatten_graph: bool = True,
    return_cell: bool = False,
) -> None | Graph:
    """Builds the computational graph from a parse tree.

    Args:
        tree (nx.DiGraph): parse tree.
        terminal_to_torch_map (dict): Mapping from terminal symbols to primitives or topologies.
        node_label (str, optional): Key to access terminal symbol. Defaults to "op_name".
        flatten_graph (bool, optional): Whether to inline nested Graph ops into a
            single flat graph (see ``_flatten_graph``). Defaults to True.
        return_cell (bool, optional): Whether to return a cell. Is only needed if cell is repeated multiple times.
        Defaults to False.

    Returns:
        None | Graph: None if the graph is written into ``self``; otherwise
        (``return_cell=True``) the built cell, flattened if ``flatten_graph``.

    Raises:
        Exception: If a terminal is missing from ``terminal_to_torch_map``, or a
            node's leftmost child is not a primitive/topology (possibly behind a
            chain of single-child nodes).
    """

    def _build_graph_from_tree(
        visited: set,
        tree: nx.DiGraph,
        node: int,
        terminal_to_torch_map: dict,
        node_label: str,
        is_primitive: bool = False,
    ):
        """Recursive DFS-esque function to build computational graph from parse tree

        Args:
            visited (set): set of visited nodes. NOTE(review): this function never
                adds to it, so it only skips nodes the caller pre-populated.
            tree (nx.DiGraph): parse tree.
            node (int): node index.
            terminal_to_torch_map (dict): mapping from terminal symbols to primitives or topologies.
            node_label (str): key to access operation name
            is_primitive (bool, optional): True while descending into the
                arguments of a primitive op. Terminal labels are then returned
                as raw values instead of being looked up in
                ``terminal_to_torch_map``.

        Raises:
            Exception: primitive or topology is unknown, i.e., it is probably missing in the terminal to
            torch mapping
            Exception: leftmost children can only be primitive, topology or have one child

        Returns:
            One of the following:
            - the raw label, for a leaf when ``is_primitive`` is set;
            - a deepcopy of the mapped terminal, for other leaves;
            - a ``{label: value}`` dict, when ``is_primitive`` is set;
            - a dict with key "op" merged with the argument dicts, for
              primitive ops;
            - ``graph_el(*subgraphs)``, for topologies;
            - None, if ``node`` is in ``visited``.
        """
        if node not in visited:
            subgraphs = []
            primitive_hps = []
            # Leaf: either a raw value (inside a primitive) or a mapped terminal.
            if len(tree.out_edges(node)) == 0:
                if is_primitive:
                    return tree.nodes[node][node_label]
                else:
                    if (
                        tree.nodes[node][node_label]
                        not in terminal_to_torch_map.keys()
                    ):
                        raise Exception(
                            f"Unknown primitive or topology: {tree.nodes[node][node_label]}"
                        )
                    return deepcopy(
                        terminal_to_torch_map[tree.nodes[node][node_label]]
                    )
            # A single child carries no structure of its own: descend into it.
            if len(tree.out_edges(node)) == 1:
                return _build_graph_from_tree(
                    visited,
                    tree,
                    list(tree.neighbors(node))[0],
                    terminal_to_torch_map,
                    node_label,
                    is_primitive,
                )
            # for idx, neighbor in enumerate(tree.neighbors(node)):
            for idx, neighbor in enumerate(
                self._get_neighbors_from_parse_tree(tree, node)
            ):
                if idx == 0:  # topology or primitive
                    # Follow single-child nonterminals down to the terminal symbol.
                    n = neighbor
                    while not tree.nodes[n]["terminal"]:
                        if len(tree.out_edges(n)) != 1:
                            raise Exception(
                                "Leftmost Child can only be primitive, topology or recursively have one child!"
                            )
                        n = next(tree.neighbors(n))
                    if is_primitive:
                        primitive_hp_key = tree.nodes[n][node_label]
                        primitive_hp_dict = {primitive_hp_key: None}
                        is_primitive_op = True
                    else:
                        if (
                            tree.nodes[n][node_label]
                            not in terminal_to_torch_map.keys()
                        ):
                            raise Exception(
                                f"Unknown primitive or topology: {tree.nodes[n][node_label]}"
                            )
                        graph_el = terminal_to_torch_map[tree.nodes[n][node_label]]
                        # partial objects are unwrapped to their underlying callable for the check.
                        is_primitive_op = issubclass(
                            graph_el.func
                            if isinstance(graph_el, partial)
                            else graph_el,
                            AbstractPrimitive,
                        )
                elif not tree.nodes[neighbor][
                    "terminal"
                ]:  # exclude '[' ']' ... symbols
                    if is_primitive:
                        primitive_hp_dict[primitive_hp_key] = _build_graph_from_tree(
                            visited,
                            tree,
                            neighbor,
                            terminal_to_torch_map,
                            node_label,
                            is_primitive_op,
                        )
                    elif is_primitive_op:
                        primitive_hps.append(
                            _build_graph_from_tree(
                                visited,
                                tree,
                                neighbor,
                                terminal_to_torch_map,
                                node_label,
                                is_primitive_op,
                            )
                        )
                    else:
                        subgraphs.append(
                            _build_graph_from_tree(
                                visited,
                                tree,
                                neighbor,
                                terminal_to_torch_map,
                                node_label,
                                is_primitive_op,
                            )
                        )
                elif (
                    tree.nodes[neighbor][node_label] in terminal_to_torch_map.keys()
                ):  # exclude '[' ']' ... symbols
                    # TODO check if there is a potential bug here?
                    subgraphs.append(
                        deepcopy(
                            terminal_to_torch_map[tree.nodes[neighbor][node_label]]
                        )
                    )

            if is_primitive:
                return primitive_hp_dict
            elif is_primitive_op:
                # Merge the primitive's argument dicts under the "op" entry
                # (ChainMap: the first mapping wins on duplicate keys).
                return dict(
                    collections.ChainMap(*([{"op": graph_el}] + primitive_hps))
                )
            else:
                return graph_el(*subgraphs)

    def _flatten_graph(
        graph,
        flattened_graph,
        start_node: int = None,
        end_node: int = None,
    ):
        """Copy the edges of ``graph`` into ``flattened_graph``, recursively
        inlining every edge whose "op" is itself a Graph.

        Each node of ``graph`` gets a fresh index (max existing index + 1).
        When given, a node with in-degree 0 is mapped to ``start_node`` and a
        node with out-degree 0 to ``end_node``. An inlined subgraph is thus
        spliced in between the endpoints of the edge it replaces.

        Returns:
            The updated ``flattened_graph``.
        """
        # Maps node indices of ``graph`` to node indices of ``flattened_graph``.
        nodes: dict = {}
        for u, v, data in graph.edges(data=True):
            if u in nodes.keys():
                _u = nodes[u]
            else:
                _u = (
                    1
                    if len(flattened_graph.nodes.keys()) == 0
                    else max(flattened_graph.nodes.keys()) + 1
                )
                _u = (
                    start_node
                    if graph.in_degree(u) == 0 and start_node is not None
                    else _u
                )
                nodes[u] = _u
                if _u not in flattened_graph.nodes.keys():
                    flattened_graph.add_node(_u)

            if v in nodes.keys():
                _v = nodes[v]
            else:
                _v = max(flattened_graph.nodes.keys()) + 1
                _v = (
                    end_node
                    if graph.out_degree(v) == 0 and end_node is not None
                    else _v
                )
                nodes[v] = _v
                if _v not in flattened_graph.nodes.keys():
                    flattened_graph.add_node(_v)

            if isinstance(data["op"], Graph):
                flattened_graph = _flatten_graph(
                    data["op"], flattened_graph, start_node=_u, end_node=_v
                )
            else:
                flattened_graph.add_edge(_u, _v)
                flattened_graph.edges[_u, _v].update(data)

        return flattened_graph

    root_node = self._find_root(tree)
    graph = _build_graph_from_tree(
        set(), tree, root_node, terminal_to_torch_map, node_label
    )
    self._check_graph(graph)
    if return_cell:
        cell = (
            _flatten_graph(graph, flattened_graph=Graph()) if flatten_graph else graph
        )
        return cell
    else:
        if flatten_graph:
            _flatten_graph(graph, flattened_graph=self)
        else:
            # Store the whole unflattened graph as the op of a single edge 0 -> 1.
            # NOTE(review): node 0 sidesteps the ">= 1" index convention that
            # Graph.add_node asserts; confirm this is intended.
            self.add_edge(0, 1)
            self.edges[0, 1].set("op", graph)
        return None

clone #

clone()

Deep copy of the current graph.

RETURNS DESCRIPTION
Graph

Deep copy of the graph.

Source code in neps/search_spaces/architecture/graph.py
def clone(self):
    """
    Return a fully independent (deep) copy of this graph.

    Returns:
        Graph: Deep copy of the graph.
    """
    duplicate = copy.deepcopy(self)
    return duplicate

compile #

compile()

Instantiates the ops at the edges using the arguments specified at the edges.

Source code in neps/search_spaces/architecture/graph.py
def compile(self):
    """
    Instantiate the ops stored on the edges, using the remaining edge
    attributes as constructor arguments.

    The child graphs (``self._get_child_graphs(single_instances=False)``) are
    processed first, then ``self``. Edges whose data reports ``is_final()``
    are left untouched. Per edge, the "op" attribute may be one of:

    * a list of ops: every class in it is instantiated. A list-valued
      attribute supplies the i-th op's argument; a scalar attribute is
      shared by all ops. Entries that are already instances are kept as is.
    * an AbstractPrimitive instance: already compiled, left as is.
    * an AbstractPrimitive subclass: instantiated with the remaining
      attributes (without "op_name").
    * a Graph: skipped here, because it is compiled as a child graph.

    Raises:
        ValueError: If the op has any other format.
    """
    for graph in self._get_child_graphs(single_instances=False) + [self]:
        logger.debug(f"Compiling graph {graph.name}")
        for _, _, edge_data in graph.edges.data():
            if edge_data.is_final():
                continue
            attr = edge_data.to_dict()
            op = attr.pop("op")

            if isinstance(op, list):
                compiled_ops = []
                for i, o in enumerate(op):
                    if inspect.isclass(o):
                        # get the relevant parameter if there are more.
                        op_kwargs = {
                            key: value[i] if isinstance(value, list) else value
                            for key, value in attr.items()
                        }
                        compiled_ops.append(o(**op_kwargs))
                    else:
                        logger.debug(f"op {o} already compiled. Skipping")
                        # Keep the existing instance. Dropping it would shrink the
                        # op list (and empty it entirely on a repeated compile).
                        compiled_ops.append(o)
                edge_data.set("op", compiled_ops)
            elif isinstance(op, AbstractPrimitive):
                logger.debug(f"op {op} already compiled. Skipping")
            elif inspect.isclass(op) and issubclass(op, AbstractPrimitive):
                # Init the class without the "op_name" attribute.
                if "op_name" in attr:
                    del attr["op_name"]
                edge_data.set("op", op(**attr))
            elif isinstance(op, Graph):
                pass  # This is already covered by _get_child_graphs
            else:
                raise ValueError(f"Unkown format of op: {op}")

copy #

copy()

Copy as defined in networkx, i.e. a shallow copy.

Recursively nested graphs are handled separately.

Source code in neps/search_spaces/architecture/graph.py
def copy(self):
    """
    Copy as defined in networkx, i.e. a shallow copy, with nested graphs
    handled separately.

    Each node's attribute dict is copied. Within it:
    - Graph values, including Graphs inside lists, are copied via their own
      ``copy()``;
    - torch modules and AbstractPrimitive instances are deep-copied.

    Edge data objects are copied via their ``copy()`` method. The graph-level
    attribute dict, ``scope`` and ``name`` are carried over.
    """

    def _copy_node_attrs(attrs):
        duplicated = attrs.copy()
        for key, value in attrs.items():
            if isinstance(value, Graph):
                duplicated[key] = value.copy()
            elif isinstance(value, list):
                duplicated[key] = [
                    item.copy() if isinstance(item, Graph) else item
                    for item in value
                ]
            elif isinstance(value, (torch.nn.Module, AbstractPrimitive)):
                duplicated[key] = copy.deepcopy(value)
        return duplicated

    new_graph = self.__class__()
    new_graph.graph.update(self.graph)
    new_graph.add_nodes_from(
        (node, _copy_node_attrs(attrs)) for node, attrs in self._node.items()
    )
    edge_triples = (
        (src, dst, edge_data.copy())
        for src, neighbours in self._adj.items()
        for dst, edge_data in neighbours.items()
    )
    new_graph.add_edges_from(edge_triples)
    new_graph.scope = self.scope
    new_graph.name = self.name
    return new_graph

forward #

forward(x, *args)

Forward some data through the graph. This is done recursively in case there are graphs defined on nodes or as 'op' on edges.

PARAMETER DESCRIPTION
x

The input. If the graph sits on a node the input can be a dict with {source_idx: Tensor} to be routed to the defined input nodes. If the graph sits on an edge, x is the feature tensor.

TYPE: Tensor or dict

args

This is only required to handle cases where the graph sits on an edge and receives an EdgeData object which will be ignored

DEFAULT: ()

Source code in neps/search_spaces/architecture/graph.py
def forward(self, x, *args):
    """
    Forward some data through the graph. This is done recursively for graphs
    stored as a node's 'subgraph' or as an edge's 'op'.

    Args:
        x (Tensor or dict): The input. If the graph sits on a node the
            input can be a dict with {source_idx: Tensor} to be routed
            to the defined input nodes. If the graph sits on an edge,
            x is the feature tensor.
        args: Ignored. Accepted only so that a graph sitting on an edge can
            receive (and ignore) an EdgeData object.

    Returns:
        The value computed for the last node of the lexicographical
        topological order. Outputs of other sink nodes are stored in
        ``self.graph[f"out_from_{node_idx}"]``.

    Raises:
        ValueError: If an edge op is neither a Graph nor an AbstractPrimitive.
    """
    logger.debug(f"Graph {self.name} called. Input {log_formats(x)}.")

    # Assign x to the corresponding input nodes
    self._assign_x_to_nodes(x)

    # The topology is not modified while forwarding, so sort once. The old
    # code re-sorted the whole graph for every node, which was O(n^2).
    node_order = list(lexicographical_topological_sort(self))
    final_node_idx = node_order[-1] if node_order else None

    for node_idx in node_order:
        node = self.nodes[node_idx]
        logger.debug(
            "Node {}-{}, current data {}, start processing...".format(
                self.name, node_idx, log_formats(node)
            )
        )

        # node internal: process input if necessary
        # NOTE(review): this warns when exactly one of 'subgraph'/'comb_op' is
        # present, while the message refers to both being defined. Confirm the
        # intended condition.
        if ("subgraph" in node and "comb_op" not in node) or (
            "comb_op" in node and "subgraph" not in node
        ):
            log_first_n(
                logging.WARN, "Comb_op is ignored if subgraph is defined!", n=1
            )
        # TODO: merge 'subgraph' and 'comb_op'. It is basicallly the same thing. Also in parse()
        if "subgraph" in node:
            x = node["subgraph"].forward(node["input"])
        else:
            if len(node["input"].values()) == 1:
                x = list(node["input"].values())[0]
            else:
                # Combine the inputs ordered by their source node index.
                x = node["comb_op"](
                    [node["input"][k] for k in sorted(node["input"].keys())]
                )
        node["input"] = {}  # clear the input as we have processed it

        if (
            len(list(self.neighbors(node_idx))) == 0
            and node_idx < final_node_idx
        ):
            # We have more than one output node. This is e.g. the case for
            # auxillary losses. Attach them to the graph, handling must done
            # by the user.
            logger.debug(
                "Graph {} has more then one output node. Storing output of non-maximum index node {} at graph dict".format(
                    self, node_idx
                )
            )
            self.graph[f"out_from_{node_idx}"] = x
        else:
            # outgoing edges: process all outgoing edges
            for neighbor_idx in self.neighbors(node_idx):
                edge_data = self.get_edge_data(node_idx, neighbor_idx)
                # inject edge data only for AbstractPrimitive, not Graphs
                if isinstance(edge_data.op, Graph):
                    edge_output = edge_data.op.forward(x)
                elif isinstance(edge_data.op, AbstractPrimitive):
                    logger.debug(
                        "Processing op {} at edge {}-{}".format(
                            edge_data.op, node_idx, neighbor_idx
                        )
                    )
                    edge_output = edge_data.op.forward(x)
                else:
                    raise ValueError(
                        "Unknown class as op: {}. Expected either Graph or AbstactPrimitive".format(
                            edge_data.op
                        )
                    )
                self.nodes[neighbor_idx]["input"].update({node_idx: edge_output})

        logger.debug(f"Node {self.name}-{node_idx}, processing done.")

    logger.debug(f"Graph {self.name} exiting. Output {log_formats(x)}.")
    return x

from_nxTree_to_stringTree #

from_nxTree_to_stringTree(
    nxTree: DiGraph, node_label: str = "op_name"
) -> str

Transforms parse tree represented as NetworkX DAG to string representation.

PARAMETER DESCRIPTION
nxTree

parse tree.

TYPE: DiGraph

node_label

key to access operation names. Defaults to "op_name".

TYPE: str DEFAULT: 'op_name'

RETURNS DESCRIPTION
str

parse tree represented as string.

TYPE: str

Source code in neps/search_spaces/architecture/core_graph_grammar.py
def from_nxTree_to_stringTree(
    self, nxTree: nx.DiGraph, node_label: str = "op_name"
) -> str:
    """Serialize a NetworkX parse tree into its bracketed string form.

    Traversal starts at ``self._find_root(nxTree)``. A terminal node becomes
    its label. Every other node becomes ``(<label> <child> <child> ...)``,
    with children ordered by ``self._get_neighbors_from_parse_tree``. A node
    reached a second time contributes an empty string.

    Args:
        nxTree (nx.DiGraph): parse tree.
        node_label (str, optional): key to access operation names. Defaults to "op_name".

    Returns:
        str: parse tree represented as string.
    """
    seen = set()

    def render(node):
        if node in seen:
            return ""
        seen.add(node)
        attrs = nxTree.nodes[node]
        if attrs["terminal"]:
            return f"{attrs[node_label]}"
        label = f"{attrs[node_label]}"
        pieces = []
        for child in self._get_neighbors_from_parse_tree(nxTree, node):
            pieces.append(render(child))
        if not pieces:
            return f"({label})"
        return f"({label} {' '.join(pieces)})"

    return render(self._find_root(nxTree))

from_stringTree_to_graph_repr #

from_stringTree_to_graph_repr(
    string_tree: str,
    grammar: Grammar,
    valid_terminals: KeysView,
    edge_attr: bool = True,
    sym_name: str = "op_name",
    prune: bool = True,
    add_subtree_map: bool = False,
    return_all_subgraphs: bool = None,
    return_graph_per_hierarchy: bool = None,
) -> DiGraph | tuple[DiGraph, OrderedDict]

Generates graph from parse tree in string representation. Note that we ignore primitive HPs!

PARAMETER DESCRIPTION
string_tree

parse tree.

TYPE: str

grammar

underlying grammar.

TYPE: Grammar

valid_terminals

list of keys.

TYPE: list

edge_attr

Should the graph be edge attributed (True) or node attributed (False). Defaults to True.

TYPE: bool DEFAULT: True

sym_name

Attribute name of operation. Defaults to "op_name".

TYPE: str DEFAULT: 'op_name'

prune

Prune graph, e.g., None operations etc. Defaults to True.

TYPE: bool DEFAULT: True

add_subtree_map

Add attribute indicating to which subtrees of the parse tree the specific part belongs to. Can only be true if you set prune=False! TODO: Check if we really need this constraint or can also allow pruning. Defaults to False.

TYPE: bool DEFAULT: False

return_all_subgraphs

Additionally returns a hierarchical dictionary containing all subgraphs. Defaults to False. TODO: check if edge attr also works.

TYPE: bool DEFAULT: None

return_graph_per_hierarchy

Additionally returns a graph for each hierarchy level.

TYPE: bool DEFAULT: None

RETURNS DESCRIPTION
DiGraph | tuple[DiGraph, OrderedDict]

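nx.DiGraph | tuple[nx.DiGraph, OrderedDict]: the generated graph, possibly paired with an OrderedDict of additional outputs (per the return annotation).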

Source code in neps/search_spaces/architecture/core_graph_grammar.py
def from_stringTree_to_graph_repr(
    self,
    string_tree: str,
    grammar: Grammar,
    valid_terminals: collections.abc.KeysView,
    edge_attr: bool = True,
    sym_name: str = "op_name",
    prune: bool = True,
    add_subtree_map: bool = False,
    return_all_subgraphs: bool = None,
    return_graph_per_hierarchy: bool = None,
) -> nx.DiGraph | tuple[nx.DiGraph, collections.OrderedDict]:
    """Generates graph from parse tree in string representation.
    Note that we ignore primitive HPs!

    Terminals that are keys of ``self.terminal_to_graph_edges`` are expanded into
    their topology, replacing the edge (edge-attributed) or node (node-attributed)
    currently being expanded; all other valid terminals are primitive operations
    that label the next pending edge/node.

    Args:
        string_tree (str): parse tree.
        grammar (Grammar): underlying grammar.
        valid_terminals (list): list of keys.
        edge_attr (bool, optional): Should graph be edge attributed (True) or node
            attributed (False). Defaults to True.
        sym_name (str, optional): Attribute name of operation. Defaults to "op_name".
        prune (bool, optional): Prune graph, e.g., None operations etc. Defaults to True.
        add_subtree_map (bool, optional): Add attribute indicating to which subtrees of
            the parse tree the specific part belongs to. Can only be true if you set prune=False!
            TODO: Check if we really need this constraint or can also allow pruning. Defaults to False.
        return_all_subgraphs (bool, optional): Additionally returns an hierarchical dictionary
            containing all subgraphs. Defaults to None, i.e., ``self.return_all_subgraphs``.
            TODO: check if edge attr also works.
        return_graph_per_hierarchy (bool, optional): Additionally returns a graph for each
            hierarchy level. Defaults to None, i.e., ``self.return_graph_per_hierarchy``.

    Raises:
        Exception: if ``string_tree`` is invalid (unconsumed edges/nodes remain, or
            pruning is requested but no topology terminal was expanded).

    Returns:
        nx.DiGraph | list: the graph ``G`` alone, or -- if ``return_all_subgraphs`` or
            ``return_graph_per_hierarchy`` is set -- a list ``[G, ...]`` followed by the
            subgraph dictionary and/or the per-hierarchy graphs (in that order).
    """

    def get_node_labels(graph: nx.DiGraph):
        # (node, label) pairs of all nodes except the synthetic input/output nodes
        return [
            (n, d[sym_name])
            for n, d in graph.nodes(data=True)
            if d[sym_name] != "input" and d[sym_name] != "output"
        ]

    def get_hierarchy_dict(
        string_tree: str,
        subgraphs: dict,
        hierarchy_dict: dict = None,
        hierarchy_level_counter: int = 0,
    ):
        # maps hierarchy level -> list of subtree identifiers found at that level
        if hierarchy_dict is None:
            hierarchy_dict = {}
        if hierarchy_level_counter not in hierarchy_dict.keys():
            hierarchy_dict[hierarchy_level_counter] = []
        hierarchy_dict[hierarchy_level_counter].append(string_tree)
        node_labels = get_node_labels(subgraphs[string_tree])
        for _, node_label in node_labels:
            if node_label in subgraphs.keys():
                hierarchy_dict = get_hierarchy_dict(
                    node_label, subgraphs, hierarchy_dict, hierarchy_level_counter + 1
                )
        return hierarchy_dict

    def get_graph_per_hierarchy(string_tree: str, subgraphs: dict):
        # level k graph = level k-1 graph with every subgraph-labelled node inlined
        hierarchy_dict = get_hierarchy_dict(
            string_tree=string_tree, subgraphs=subgraphs
        )

        graph_per_hierarchy = collections.OrderedDict()
        for k, v in hierarchy_dict.items():
            if k == 0:
                graph_per_hierarchy[k] = subgraphs[v[0]]
            else:
                subgraph_ = graph_per_hierarchy[k - 1].copy()
                node_labels = get_node_labels(subgraph_)
                for node, node_label in node_labels:
                    if node_label in subgraphs:
                        in_nodes = list(subgraph_.predecessors(node))
                        out_nodes = list(subgraph_.successors(node))
                        node_offset = max(subgraph_.nodes) + 1

                        new_subgraph = nx.relabel.relabel_nodes(
                            subgraphs[node_label],
                            mapping={
                                n: n + node_offset
                                for n in subgraphs[node_label].nodes
                            },
                            copy=True,
                        )
                        first_nodes = {e[0] for e in new_subgraph.edges}
                        second_nodes = {e[1] for e in new_subgraph.edges}
                        (begin_node,) = first_nodes - second_nodes
                        (end_node,) = second_nodes - first_nodes
                        successors = list(new_subgraph.successors(begin_node))
                        predecessors = list(new_subgraph.predecessors(end_node))
                        new_subgraph.remove_nodes_from([begin_node, end_node])
                        edges = []
                        added_identities = False
                        for in_node in in_nodes:
                            for succ in successors:
                                if succ == end_node:
                                    # direct begin->end connection: bridge all
                                    # in-nodes to all out-nodes (only once)
                                    if not added_identities:
                                        edges.extend(
                                            [
                                                (inn, onn)
                                                for inn in in_nodes
                                                for onn in out_nodes
                                            ]
                                        )
                                    added_identities = True
                                else:
                                    edges.append((in_node, succ))
                        for out_node in out_nodes:
                            for pred in predecessors:
                                if pred != begin_node:
                                    edges.append((pred, out_node))

                        subgraph_ = nx.compose(new_subgraph, subgraph_)
                        subgraph_.add_edges_from(edges)
                        subgraph_.remove_node(node)

                graph_per_hierarchy[k] = subgraph_
        return graph_per_hierarchy

    def to_node_attributed_edge_list(
        edge_list: list[tuple],
    ) -> tuple[list[tuple[int, int]], dict]:
        # every edge becomes a node (ids start at 2; 0/1 are input/output)
        node_offset = 2
        edge_to_node_map = {e: i + node_offset for i, e in enumerate(edge_list)}
        first_nodes = {e[0] for e in edge_list}
        second_nodes = {e[1] for e in edge_list}
        (src,) = first_nodes - second_nodes
        (tgt,) = second_nodes - first_nodes
        node_list = []
        for e in edge_list:
            ni = edge_to_node_map[e]
            u, v = e
            if u == src:
                node_list.append((0, ni))
            if v == tgt:
                node_list.append((ni, 1))

            for e_ in filter(
                lambda e: (e[1] == u), edge_list
            ):
                node_list.append((edge_to_node_map[e_], ni))

        return node_list, edge_to_node_map

    def skip_char(char: str) -> bool:
        return True if char in [" ", "\t", "\n", "[", "]"] else False

    if prune:
        add_subtree_map = False

    if return_all_subgraphs is None:
        return_all_subgraphs = self.return_all_subgraphs
    if return_graph_per_hierarchy is None:
        return_graph_per_hierarchy = self.return_graph_per_hierarchy
    compute_subgraphs = return_all_subgraphs or return_graph_per_hierarchy

    G = nx.DiGraph()
    if add_subtree_map:
        q_nonterminals: Deque = collections.deque()
    if compute_subgraphs:
        q_subtrees: Deque = collections.deque()
        q_subgraphs: Deque = collections.deque()
        subgraphs_dict = collections.OrderedDict()
    if edge_attr:
        node_offset = 0
        q_el: Deque = collections.deque()  # edge-attr
        terminal_to_graph = self.terminal_to_graph_edges
    else:  # node-attributed
        G.add_node(0, **{sym_name: "input"})
        G.add_node(1, **{sym_name: "output"})
        node_offset = 2
        if bool(self.terminal_to_graph_nodes):
            terminal_to_graph_nodes = self.terminal_to_graph_nodes
        else:
            terminal_to_graph_nodes = {
                k: to_node_attributed_edge_list(edge_list) if edge_list else []
                for k, edge_list in self.terminal_to_graph_edges.items()
            }
            # cache on the instance so the conversion is done only once
            self.terminal_to_graph_nodes = terminal_to_graph_nodes
        terminal_to_graph = {
            k: v[0] if v else [] for k, v in terminal_to_graph_nodes.items()
        }
        q_el = collections.deque()  # node-attr

    # pre-compute stuff
    begin_end_nodes = {}
    for sym, g in terminal_to_graph.items():
        if g:
            first_nodes = {e[0] for e in g}
            second_nodes = {e[1] for e in g}
            (begin_node,) = first_nodes - second_nodes
            (end_node,) = second_nodes - first_nodes
            begin_end_nodes[sym] = (begin_node, end_node)
        else:
            begin_end_nodes[sym] = (None, None)

    # set when the first topology terminal is expanded; needed for pruning
    src_node = sink_node = None
    for split_idx, sym in enumerate(string_tree.split(" ")):
        is_nonterminal = False
        if sym == "":
            continue
        if compute_subgraphs:
            new_sym = True
            sym_copy = sym[:]
        if sym[0] == "(":
            sym = sym[1:]
            is_nonterminal = True
        if sym[-1] == ")":
            if add_subtree_map:
                for _ in range(sym.count(")")):
                    q_nonterminals.pop()
            if compute_subgraphs:
                new_sym = False
            # strip closing parentheses unless ")" is part of a terminal
            while sym[-1] == ")" and sym not in valid_terminals:
                sym = sym[:-1]

        if compute_subgraphs and new_sym:
            if sym in grammar.nonterminals:
                # need dict as a graph can have multiple subgraphs
                q_subtrees.append(sym_copy[:])
            else:
                q_subtrees[-1] += f" {sym_copy}"

        if len(sym) == 1 and skip_char(sym[0]):
            continue

        if add_subtree_map and sym in grammar.nonterminals:
            q_nonterminals.append((sym, split_idx))
        elif sym in valid_terminals and not is_nonterminal:  # terminal symbol
            if sym in self.terminal_to_graph_edges:
                if len(q_el) == 0:
                    # first topology: it defines the overall source/sink
                    if edge_attr:
                        edges = [
                            tuple(t + node_offset for t in e)
                            for e in self.terminal_to_graph_edges[sym]
                        ]
                    else:  # node-attr
                        edges = [
                            tuple(t for t in e)
                            for e in terminal_to_graph_nodes[sym][0]
                        ]
                        nodes = [
                            terminal_to_graph_nodes[sym][1][e]
                            for e in self.terminal_to_graph_edges[sym]
                        ]
                    if add_subtree_map:
                        subtrees = []
                    first_nodes = {e[0] for e in edges}
                    second_nodes = {e[1] for e in edges}
                    (src_node,) = first_nodes - second_nodes
                    (sink_node,) = second_nodes - first_nodes
                else:
                    # replace the pending edge/node by this topology
                    begin_node, end_node = begin_end_nodes[sym]
                    el = q_el.pop()
                    if edge_attr:
                        u, v = el
                        if add_subtree_map:
                            subtrees = G[u][v]["subtrees"]
                        G.remove_edge(u, v)
                        edges = [
                            tuple(
                                u
                                if t == begin_node
                                else v
                                if t == end_node
                                else t + node_offset
                                for t in e
                            )
                            for e in self.terminal_to_graph_edges[sym]
                        ]
                    else:  # node-attr
                        n = el
                        if add_subtree_map:
                            subtrees = G.nodes[n]["subtrees"]
                        in_nodes = list(G.predecessors(n))
                        out_nodes = list(G.successors(n))
                        G.remove_node(n)
                        edges = []
                        for e in terminal_to_graph_nodes[sym][0]:
                            if not (e[0] == begin_node or e[1] == end_node):
                                edges.append((e[0] + node_offset, e[1] + node_offset))
                            elif e[0] == begin_node:
                                for nin in in_nodes:
                                    edges.append((nin, e[1] + node_offset))
                            elif e[1] == end_node:
                                for nout in out_nodes:
                                    edges.append((e[0] + node_offset, nout))
                        nodes = [
                            terminal_to_graph_nodes[sym][1][e] + node_offset
                            for e in self.terminal_to_graph_edges[sym]
                        ]

                G.add_edges_from(edges)

                if compute_subgraphs:
                    subgraph = nx.DiGraph()
                    subgraph.add_edges_from(edges)
                    q_subgraphs.append(
                        {
                            "graph": subgraph,
                            "atoms": collections.OrderedDict(
                                (atom, None)
                                for atom in (edges if edge_attr else nodes)
                            ),
                        }
                    )

                if add_subtree_map:
                    if edge_attr:
                        subtrees.append(q_nonterminals[-1])
                        for u, v in edges:
                            G[u][v]["subtrees"] = subtrees.copy()
                    else:  # node-attr
                        subtrees.append(q_nonterminals[-1])
                        for n in nodes:
                            G.nodes[n]["subtrees"] = subtrees.copy()

                # reversed so that pop() yields edges/nodes in original order
                q_el.extend(reversed(edges if edge_attr else nodes))
                if edge_attr:
                    # use the largest node id over all edges; max(max(edges)) only
                    # inspected the lexicographically largest edge and could miss it
                    node_offset += max(
                        max(e) for e in self.terminal_to_graph_edges[sym]
                    )
                else:
                    node_offset += max(terminal_to_graph_nodes[sym][1].values())
            else:  # primitive operations
                el = q_el.pop()
                if edge_attr:
                    u, v = el
                    if prune and sym in self.zero_op:
                        G.remove_edge(u, v)
                        if compute_subgraphs:
                            q_subgraphs[-1]["graph"].remove_edge(u, v)
                            del q_subgraphs[-1]["atoms"][(u, v)]
                    else:
                        G[u][v][sym_name] = sym
                        if compute_subgraphs:
                            q_subgraphs[-1]["graph"][u][v][sym_name] = sym
                        if add_subtree_map:
                            G[u][v]["subtrees"].append(q_nonterminals[-1])
                            q_nonterminals.pop()
                else:  # node-attr
                    n = el
                    if prune and sym in self.zero_op:
                        G.remove_node(n)
                        if compute_subgraphs:
                            q_subgraphs[-1]["graph"].remove_node(n)
                            del q_subgraphs[-1]["atoms"][n]
                    elif prune and sym in self.identity_op:
                        # bypass the identity node: connect its predecessors
                        # directly to its successors
                        G.add_edges_from(
                            [
                                (n_in, n_out)
                                for n_in in G.predecessors(n)
                                for n_out in G.successors(n)
                            ]
                        )
                        G.remove_node(n)
                        if compute_subgraphs:
                            q_subgraphs[-1]["graph"].add_edges_from(
                                [
                                    (n_in, n_out)
                                    for n_in in q_subgraphs[-1]["graph"].predecessors(
                                        n
                                    )
                                    for n_out in q_subgraphs[-1]["graph"].successors(
                                        n
                                    )
                                ]
                            )
                            q_subgraphs[-1]["graph"].remove_node(n)
                            del q_subgraphs[-1]["atoms"][n]
                    else:
                        G.nodes[n][sym_name] = sym
                        if compute_subgraphs:
                            q_subgraphs[-1]["graph"].nodes[n][sym_name] = sym
                            # label the first still-unlabelled atom
                            q_subgraphs[-1]["atoms"][
                                next(
                                    filter(
                                        lambda x: x[1] is None,
                                        q_subgraphs[-1]["atoms"].items(),
                                    )
                                )[0]
                            ] = sym
                        if add_subtree_map:
                            G.nodes[n]["subtrees"].append(q_nonterminals[-1])
                            q_nonterminals.pop()
        if compute_subgraphs and sym_copy[-1] == ")":
            q_subtrees[-1] += f" {sym_copy}"
            for _ in range(sym_copy.count(")")):
                subtree_identifier = q_subtrees.pop()
                if len(q_subtrees) > 0:
                    q_subtrees[-1] += f" {subtree_identifier}"
                if len(q_subtrees) == len(q_subgraphs) - 1:
                    # drop surplus closing parentheses from the identifier
                    difference = subtree_identifier.count(
                        "("
                    ) - subtree_identifier.count(")")
                    if difference < 0:
                        subtree_identifier = subtree_identifier[:difference]
                    subgraph_dict = q_subgraphs.pop()
                    subgraph = subgraph_dict["graph"]
                    atoms = subgraph_dict["atoms"]
                    if len(q_subtrees) > 0:
                        # subtree_identifier is subgraph graph at [-1]
                        # (and sub-...-subgraph currently in q_subgraphs)
                        q_subgraphs[-1]["atoms"][
                            next(
                                filter(
                                    lambda x: x[1] is None,
                                    q_subgraphs[-1]["atoms"].items(),
                                )
                            )[0]
                        ] = subtree_identifier

                    for atom in filter(lambda x: x[1] is not None, atoms.items()):
                        if edge_attr:
                            subgraph[atom[0][0]][atom[0][1]][sym_name] = atom[1]
                        else:  # node-attr
                            subgraph.nodes[atom[0]][sym_name] = atom[1]

                    if not edge_attr:  # node-attr
                        # ensure there is actually one input and output node
                        first_nodes = {e[0] for e in subgraph.edges}
                        second_nodes = {e[1] for e in subgraph.edges}
                        new_src_node = max(subgraph.nodes) + 1
                        src_nodes = first_nodes - second_nodes
                        subgraph.add_edges_from(
                            [
                                (new_src_node, successor)
                                for src_node in src_nodes
                                for successor in subgraph.successors(src_node)
                            ]
                        )
                        subgraph.add_node(new_src_node, **{sym_name: "input"})
                        subgraph.remove_nodes_from(src_nodes)
                        new_sink_node = max(subgraph.nodes) + 1
                        sink_nodes = second_nodes - first_nodes
                        subgraph.add_edges_from(
                            [
                                (predecessor, new_sink_node)
                                for sink_node in sink_nodes
                                for predecessor in subgraph.predecessors(sink_node)
                            ]
                        )
                        subgraph.add_node(new_sink_node, **{sym_name: "output"})
                        subgraph.remove_nodes_from(sink_nodes)
                    subgraphs_dict[subtree_identifier] = subgraph

    if len(q_el) != 0:
        raise Exception("Invalid string_tree")

    if prune:
        if src_node is None:
            # no topology terminal was expanded, so there is no source/sink
            raise Exception("Invalid string_tree")
        G = self.prune_unconnected_parts(G, src_node, sink_node)
    self._check_graph(G)

    if return_all_subgraphs or return_graph_per_hierarchy:
        return_val = [G]
        subgraphs_dict = collections.OrderedDict(
            reversed(list(subgraphs_dict.items()))
        )
        if prune:
            for key, v in list(subgraphs_dict.items()):
                first_nodes = {e[0] for e in v.edges}
                second_nodes = {e[1] for e in v.edges}
                (vG_src_node,) = first_nodes - second_nodes
                (vG_sink_node,) = second_nodes - first_nodes
                v = self.prune_unconnected_parts(v, vG_src_node, vG_sink_node)
                self._check_graph(v)
                # store the pruned graph; rebinding only the loop variable
                # would silently discard the pruning result
                subgraphs_dict[key] = v
        if return_all_subgraphs:
            return_val.append(subgraphs_dict)
        if return_graph_per_hierarchy:
            graph_per_hierarchy = get_graph_per_hierarchy(string_tree, subgraphs_dict)
            _ = (
                graph_per_hierarchy.popitem()
            )  # remove last graph since it is equal to full graph
            return_val.append(graph_per_hierarchy)
        return return_val
    return G

from_stringTree_to_nxTree staticmethod #

from_stringTree_to_nxTree(
    string_tree: str,
    grammar: Grammar,
    sym_name: str = "op_name",
) -> DiGraph

Transforms a parse tree from string representation to NetworkX representation.

PARAMETER DESCRIPTION
string_tree

parse tree.

TYPE: str

grammar

context-free grammar which generated the parse tree in string representation.

TYPE: Grammar

sym_name

Key to save the terminal symbols. Defaults to "op_name".

TYPE: str DEFAULT: 'op_name'

RETURNS DESCRIPTION
DiGraph

nx.DiGraph: parse tree as NetworkX representation.

Source code in neps/search_spaces/architecture/core_graph_grammar.py
@staticmethod
def from_stringTree_to_nxTree(
    string_tree: str, grammar: Grammar, sym_name: str = "op_name"
) -> nx.DiGraph:
    """Transforms a parse tree from string representation to NetworkX representation.

    Every terminal and nonterminal becomes a node (ids start at 1) with the
    attributes ``sym_name`` (the symbol), ``"terminal"`` (whether the symbol is a
    terminal of ``grammar``) and ``"children"`` (ids of its direct children).
    Each nonterminal gets an edge to each of its children.

    Args:
        string_tree (str): parse tree.
        grammar (Grammar): context-free grammar which generated the parse tree in
            string representation. A single-element list of grammars is unwrapped;
            longer lists are not supported yet.
        sym_name (str, optional): Key to save the terminal symbols. Defaults to "op_name".

    Raises:
        NotImplementedError: if more than one grammar is passed.
        Exception: if a symbol cannot be matched or parentheses are unbalanced.

    Returns:
        nx.DiGraph: parse tree as NetworkX representation.
    """

    def skip_char(char: str, i: int) -> bool:
        if char in [" ", "\t", "\n"]:
            return True
        # special case: "(" surrounded by spaces is (part of) a terminal;
        # bounds are checked so a trailing "(" cannot raise IndexError
        if (
            char == "("
            and 0 < i < len(string_tree) - 1
            and string_tree[i - 1] == " "
            and string_tree[i + 1] == " "
        ):
            return False
        return char == "("

    def find_longest_match(
        i: int, string_tree: str, symbols: set[str], max_match: int
    ) -> int:
        # search for longest matching symbol and add it
        # assumes that the longest match is the true match; the end index may
        # reach len(string_tree) so a symbol at the very end can be matched
        j = min(i + max_match, len(string_tree))
        while j > i:
            if string_tree[i:j] in symbols:
                break
            j -= 1
        if j == i:
            raise Exception(f"Terminal or nonterminal at position {i} does not exist")
        return j

    if isinstance(grammar, list):
        if len(grammar) > 1:
            raise NotImplementedError("TODO check implementation")
        grammar = grammar[0]

    symbols = grammar.nonterminals + grammar.terminals
    max_match = max(map(len, symbols))
    find_longest_match_func = partial(
        find_longest_match,
        string_tree=string_tree,
        symbols=set(symbols),  # O(1) membership tests
        max_match=max_match,
    )

    G = nx.DiGraph()
    q: queue.LifoQueue = queue.LifoQueue()  # open nonterminal node ids
    q_children: queue.LifoQueue = queue.LifoQueue()  # their children so far
    node_number = 0
    i = 0
    while i < len(string_tree):
        char = string_tree[i]
        if skip_char(char, i):
            pass
        elif char == ")" and string_tree[i - 1] != " ":
            # closing symbol of production
            _node_number = q.get(block=False)
            _node_children = q_children.get(block=False)
            G.nodes[_node_number]["children"] = _node_children
        else:
            j = find_longest_match_func(i)
            sym = string_tree[i:j]
            i = j - 1
            node_number += 1
            G.add_node(
                node_number,
                **{
                    sym_name: sym,
                    "terminal": sym in grammar.terminals,
                    "children": [],
                },
            )
            if not q.empty():
                G.add_edge(q.queue[-1], node_number)
                q_children.queue[-1].append(node_number)
            if sym in grammar.nonterminals:
                q.put(node_number)
                q_children.put([])
        i += 1

    if len(q.queue) != 0:
        raise Exception("Invalid string_tree")
    return G

get_all_edge_data #

get_all_edge_data(
    key: str, scope="all", private_edge_data: bool = False
) -> list

Get edge attributes of this graph and all child graphs in one go.

PARAMETER DESCRIPTION
key

The key of the attribute

TYPE: str

scope

The scope to be applied

TYPE: str DEFAULT: 'all'

private_edge_data

Whether to return data from graph copies as well.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
list

All data in a list.

TYPE: list

Source code in neps/search_spaces/architecture/graph.py
def get_all_edge_data(
    self, key: str, scope="all", private_edge_data: bool = False
) -> list:
    """
    Collect the value stored under ``key`` on every edge of this graph and its
    child graphs.

    Child graphs come first, this graph last. Edges whose data does not hold
    ``key`` are skipped.

    Args:
        key (str): The key of the attribute
        scope (str): The scope to be applied; "all", a single scope name, or a
            list of scope names.
        private_edge_data (bool): Whether to return data from graph copies as well.

    Returns:
        list: All data in a list.
    """
    assert scope is not None

    def _matches_scope(candidate) -> bool:
        if scope == "all" or candidate.scope == scope:
            return True
        return isinstance(scope, list) and candidate.scope in scope

    graphs = self._get_child_graphs(single_instances=not private_edge_data)
    graphs = graphs + [self]
    return [
        data[key]
        for graph in graphs
        if _matches_scope(graph)
        for _, _, data in graph.edges.data()
        if data.has(key)
    ]

get_dense_edges #

get_dense_edges()

Returns the edge indices (i, j) that would make a fully connected DAG without cycles such that i < j. Assumes nodes are already created.

RETURNS DESCRIPTION
list

list of edge indices.

Source code in neps/search_spaces/architecture/graph.py
def get_dense_edges(self):
    """
    List every (i, j) pair with i < j over the existing nodes.

    Connecting all of these pairs yields the densest possible DAG on the
    current nodes (edges always point from a smaller to a larger index, so no
    cycle can form). Nodes must already exist.

    Returns:
        list: list of edge indices, ordered by source then target.
    """
    ordered = sorted(self.nodes())
    return [(src, dst) for src in ordered for dst in ordered if dst > src]

get_graph_representation #

get_graph_representation(
    identifier: str, grammar: Grammar, edge_attr: bool
) -> DiGraph

This function takes an identifier and constructs the (multi-variate) composition of the functions it describes. Args: identifier (str): identifier. grammar (Grammar): grammar. edge_attr (bool): whether topologies are edge attributed (True) or node attributed (False). Returns: nx.DiGraph: (multi-variate) composition of functions.

Source code in neps/search_spaces/architecture/core_graph_grammar.py
def get_graph_representation(
    self,
    identifier: str,
    grammar: Grammar,
    edge_attr: bool,
) -> nx.DiGraph:
    """This function takes an identifier and constructs the
    (multi-variate) composition of the functions it describes.

    The identifier is converted via ``self.id_to_string_tree``. A terminal
    mapped to an ``AbstractTopology`` subclass (or a ``functools.partial`` of
    one) opens a topology. Any other terminal is a primitive, and primitives are
    passed in order to the enclosing topology when its production closes.

    Args:
        identifier (str): identifier
        grammar (Grammar): grammar
        edge_attr (bool): whether topologies are edge attributed (True) or node
            attributed (False); selects the precomputed topology lookup.

    Raises:
        Exception: on an unknown symbol or an unbalanced descriptor.
        NotImplementedError: for topologies found in the precomputed lookup or
            given as ``functools.partial``.

    Returns:
        nx.DiGraph: (multi-variate) composition of functions. NOTE(review):
            currently this is the ``get_node_list_and_ops()`` result of the last
            closed topology (see TODO below); confirm its exact type.
    """

    def _skip_char(char: str) -> bool:
        return char in [" ", "\t", "\n", "[", "]"]

    def _get_sym_from_split(split: str) -> str:
        # strip leading "(" and trailing ")" characters
        start_idx, end_idx = 0, len(split)
        while start_idx < end_idx and split[start_idx] == "(":
            start_idx += 1
        while start_idx < end_idx and split[end_idx - 1] == ")":
            end_idx -= 1
        return split[start_idx:end_idx]

    def to_node_attributed_edge_list(
        edge_list: list[tuple],
    ) -> tuple[list[tuple[int, int]], dict]:
        first_nodes = {e[0] for e in edge_list}
        second_nodes = {e[1] for e in edge_list}
        src = first_nodes - second_nodes
        tgt = second_nodes - first_nodes
        node_offset = len(src)
        edge_to_node_map = {e: i + node_offset for i, e in enumerate(edge_list)}
        node_list = []
        for e in edge_list:
            ni = edge_to_node_map[e]
            u, v = e
            if u in src:
                node_list.append((u, ni))
            if v in tgt:
                node_list.append((ni, v))

            for e_ in filter(
                lambda e: (e[1] == u), edge_list
            ):
                node_list.append((edge_to_node_map[e_], ni))

        return node_list, edge_to_node_map

    descriptor = self.id_to_string_tree(identifier)

    if edge_attr:
        terminal_to_graph = self.terminal_to_graph_edges
    else:  # node-attr
        terminal_to_graph_nodes = {
            k: to_node_attributed_edge_list(edge_list) if edge_list else (None, None)
            for k, edge_list in self.terminal_to_graph_edges.items()
        }
        terminal_to_graph = {k: v[0] for k, v in terminal_to_graph_nodes.items()}

    q_nonterminals: queue.LifoQueue = queue.LifoQueue()
    q_topologies: queue.LifoQueue = queue.LifoQueue()  # [topology, #primitives]
    q_primitives: queue.LifoQueue = queue.LifoQueue()

    G = nx.DiGraph()
    for split in descriptor.split(" "):
        # empty tokens come from repeated/trailing spaces; skip them like
        # from_stringTree_to_graph_repr does instead of failing as "Unknown symbol"
        if not split or _skip_char(split):
            continue
        sym = _get_sym_from_split(split)

        if sym in grammar.terminals:
            is_topology = False
            if inspect.isclass(self.terminal_to_op_names[sym]) and issubclass(
                self.terminal_to_op_names[sym], AbstractTopology
            ):
                is_topology = True
            elif isinstance(self.terminal_to_op_names[sym], partial) and issubclass(
                self.terminal_to_op_names[sym].func, AbstractTopology
            ):
                is_topology = True

            if is_topology:
                q_topologies.put([self.terminal_to_op_names[sym], 0])
            else:  # is primitive operation
                q_primitives.put(self.terminal_to_op_names[sym])
                q_topologies.queue[-1][1] += 1  # count number of primitives
        elif sym in grammar.nonterminals:
            q_nonterminals.put(sym)
        else:
            raise Exception(f"Unknown symbol {sym}")

        if ")" in split:
            # closing symbol of production
            while ")" in split:
                # a topology closes when its production's nonterminal closes
                if q_nonterminals.qsize() == q_topologies.qsize():
                    topology, number_of_primitives = q_topologies.get(block=False)
                    # LIFO pops primitives in reverse order; restore their order
                    primitives = [
                        q_primitives.get(block=False)
                        for _ in range(number_of_primitives)
                    ][::-1]
                    if (
                        topology in terminal_to_graph
                        and terminal_to_graph[topology] is not None
                    ):
                        raise NotImplementedError
                    elif isinstance(topology, partial):
                        raise NotImplementedError
                    else:
                        composed_function = topology(*primitives)
                        node_attr_dag = composed_function.get_node_list_and_ops()
                        G = node_attr_dag  # TODO only works for DARTS for now

                    if not q_topologies.empty():
                        # the composed topology is a primitive of its parent
                        q_primitives.put(composed_function)
                        q_topologies.queue[-1][1] += 1

                _ = q_nonterminals.get(block=False)
                split = split[:-1]

    if not q_topologies.empty():
        raise Exception("Invalid descriptor")

    # G = self.prune_unconnected_parts(G, src_node, sink_node)
    # self._check_graph(G)
    return G

graph_to_self #

graph_to_self(graph: DiGraph, clear_self: bool = True)

Copies graph to self

PARAMETER DESCRIPTION
graph

graph

TYPE: DiGraph

Source code in neps/search_spaces/architecture/core_graph_grammar.py
def graph_to_self(self, graph: nx.DiGraph, clear_self: bool = True):
    """Copies graph to self

    Copies all edges of ``graph`` (with their attribute data) and all node
    attributes of ``graph`` into ``self``.

    Args:
        graph (nx.DiGraph): graph whose nodes, edges and attributes are copied.
        clear_self (bool): If True (default), ``self.clear()`` is called
            before copying, so only the content of ``graph`` remains.
    """
    if clear_self:
        self.clear()
    for u, v, data in graph.edges(data=True):
        self.add_edge(u, v)  # type: ignore[union-attr]
        self.edges[u, v].update(data)  # type: ignore[union-attr]
    for n, data in graph.nodes(data=True):
        # Nodes without any incident edge were not created by the edge loop
        # above; add them explicitly instead of failing with a KeyError.
        # Nodes already present keep their insertion order.
        if n not in self.nodes:
            self.add_node(n)  # type: ignore[union-attr]
        # update(data) instead of update(**data): also works for
        # non-string attribute keys.
        self.nodes[n].update(data)

modules_str #

modules_str()

Once the graph has been parsed, prints the modules as they appear in pytorch.

Source code in neps/search_spaces/architecture/graph.py
def modules_str(self):
    """
    Return a text summary of the pytorch modules of the parsed graph.

    For every single-instance child graph, followed by the graph itself, a
    ``Graph <name>:`` header, the ``torch.nn.Module`` representation and a
    separator line are emitted. If the graph has not been parsed yet, the
    plain ``repr`` of the graph is returned instead.
    """
    if not self.is_parsed:
        return self.__repr__()
    graphs = self._get_child_graphs(single_instances=True) + [self]
    return "".join(
        "Graph {}:\n {}\n==========\n".format(
            graph.name, torch.nn.Module.__repr__(graph)
        )
        for graph in graphs
    )

num_input_nodes #

num_input_nodes() -> int

The number of input nodes, i.e. the nodes without an incoming edge.

RETURNS DESCRIPTION
int

Number of input nodes.

TYPE: int

Source code in neps/search_spaces/architecture/graph.py
def num_input_nodes(self) -> int:
    """
    Count the input nodes of the graph.

    An input node is any node whose in-degree is zero, i.e. a node that
    has no incoming edge.

    Returns:
        int: Number of input nodes.
    """
    count = 0
    for node in self.nodes:
        if self.in_degree(node) == 0:
            count += 1
    return count

parse #

parse()

Convert the graph into a neural network which can then be optimized by pytorch.

Source code in neps/search_spaces/architecture/graph.py
def parse(self):
    """
    Convert the graph into a neural network which can then be optimized by
    pytorch.

    Nodes are visited in lexicographical topological order. For each node,
    a ``subgraph`` stored on it is parsed and registered as a submodule.
    Otherwise its ``comb_op`` is registered when it is a ``torch.nn.Module``.
    Every outgoing edge op is then registered as a submodule, after parsing
    the op itself (if it is a ``Graph``) or any ``Graph`` among its
    embedded ops. Finally ``self.is_parsed`` is set to True.
    """
    for node_idx in lexicographical_topological_sort(self):
        node_data = self.nodes[node_idx]
        if "subgraph" in node_data:
            subgraph = node_data["subgraph"]
            subgraph.parse()
            self.add_module(f"{self.name}-subgraph_at({node_idx})", subgraph)
        elif isinstance(node_data["comb_op"], torch.nn.Module):
            self.add_module(
                f"{self.name}-comb_op_at({node_idx})", node_data["comb_op"]
            )

        for neighbor_idx in self.neighbors(node_idx):
            edge_data = self.get_edge_data(node_idx, neighbor_idx)
            if isinstance(edge_data.op, Graph):
                edge_data.op.parse()
            elif edge_data.op.get_embedded_ops():
                # nested graphs can also hide inside composite ops
                for primitive in edge_data.op.get_embedded_ops():
                    if isinstance(primitive, Graph):
                        primitive.parse()
            self.add_module(
                f"{self.name}-edge({node_idx},{neighbor_idx})", edge_data.op
            )
    self.is_parsed = True

prepare_discretization #

prepare_discretization()

In some cases, e.g. DARTS, the search space is manipulated before the final discretization happens. In such cases this method should be defined in the search space so that all optimizers can call it.

Source code in neps/search_spaces/architecture/graph.py
def prepare_discretization(self):
    """
    Hook for manipulating the search space before the final discretization.

    In some cases, e.g. DARTS, the search space is manipulated before the
    final discretization happens. In such cases this method should be
    defined in the search space so that all optimizers can call it.

    The default implementation does nothing.
    """

prepare_evaluation #

prepare_evaluation()

In some cases the evaluation architecture does not match the searched one, for example when the makro_model is extended to increase the number of parameters. Such adaptations are made here.

Source code in neps/search_spaces/architecture/graph.py
def prepare_evaluation(self):
    """
    Hook for adapting the architecture right before evaluation.

    The evaluated architecture does not always match the searched one; for
    instance, the makro_model may be extended to increase the number of
    parameters. Such adaptations belong here. This default implementation
    leaves the graph unchanged.
    """
    return None

prune_tree #

prune_tree(
    tree: DiGraph,
    terminal_to_torch_map_keys: KeysView,
    node_label: str = "op_name",
) -> DiGraph

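Prunes unnecessary parts of the parse tree: nodes with only one child are collapsed into their parent, and terminal nodes whose label is not in terminal_to_torch_map_keys are removed.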

PARAMETER DESCRIPTION
tree

Parse tree

TYPE: DiGraph

RETURNS DESCRIPTION
DiGraph

nx.DiGraph: Pruned parse tree

Source code in neps/search_spaces/architecture/core_graph_grammar.py
def prune_tree(
    self,
    tree: nx.DiGraph,
    terminal_to_torch_map_keys: collections.abc.KeysView,
    node_label: str = "op_name",
) -> nx.DiGraph:
    """Prunes unnecessary parts of the parse tree.

    Performs a depth-first traversal starting at the root returned by
    ``self._find_root(tree)``. After all children of a node are processed:

    * a node with exactly one child is collapsed: the child is linked to the
      node's predecessor (taking the node's position in the predecessor's
      ``"children"`` list) and the node is removed. A root with one child
      is simply removed.
    * a terminal node whose ``node_label`` attribute is not contained in
      ``terminal_to_torch_map_keys`` is removed, together with its entry in
      the predecessor's ``"children"`` list.

    The tree is modified in place and also returned.

    Args:
        tree (nx.DiGraph): Parse tree. Nodes are read via the attributes
            ``"children"`` (ordered list of child node ids), ``"terminal"``
            and ``node_label``.
        terminal_to_torch_map_keys (KeysView): Terminal labels to keep.
        node_label (str): Node attribute that holds the symbol label.

    Returns:
        nx.DiGraph: Pruned parse tree
    """

    # Recursive post-order DFS; returns the (in-place modified) tree.
    def dfs(visited: set, tree: nx.DiGraph, node: int) -> nx.DiGraph:
        if node not in visited:
            visited.add(node)

            # Index-based loop because recursive calls may remove the
            # current child from this list. Only advance when the length is
            # unchanged. A collapsed child is replaced in place by its own
            # (already processed) child, which keeps the length the same.
            i = 0
            while i < len(tree.nodes[node]["children"]):
                former_len = len(tree.nodes[node]["children"])
                child = tree.nodes[node]["children"][i]
                tree = dfs(
                    visited,
                    tree,
                    child,
                )
                if former_len == len(tree.nodes[node]["children"]):
                    i += 1

            # Collapse a node with a single child into its predecessor.
            if len(tree.nodes[node]["children"]) == 1:
                predecessor = list(tree.pred[node])
                if len(predecessor) > 0:
                    tree.add_edge(predecessor[0], tree.nodes[node]["children"][0])
                    old_children = tree.nodes[predecessor[0]]["children"]
                    # position of `node` among the predecessor's children
                    idx = [i for i, c in enumerate(old_children) if c == node][0]
                    # insert the grandchild right after `node`, then drop
                    # `node` so the grandchild takes over its position
                    tree.nodes[predecessor[0]]["children"] = (
                        old_children[: idx + 1]
                        + [tree.nodes[node]["children"][0]]
                        + old_children[idx + 1 :]
                    )
                    tree.nodes[predecessor[0]]["children"].remove(node)

                tree.remove_node(node)
            # Drop terminals that have no mapping to a torch op.
            elif (
                tree.nodes[node]["terminal"]
                and tree.nodes[node][node_label] not in terminal_to_torch_map_keys
            ):
                # NOTE(review): a terminal without a predecessor (e.g. a
                # terminal root) would raise IndexError here.
                predecessor = list(tree.pred[node])[0]
                tree.nodes[predecessor]["children"].remove(node)
                tree.remove_node(node)
        return tree

    return dfs(set(), tree, self._find_root(tree))

reset_weights #

reset_weights(inplace: bool = False)

Resets the weights for the 'op' at all edges.

PARAMETER DESCRIPTION
inplace

Do the operation in place or return a modified copy.

TYPE: bool DEFAULT: False

RETURNS: Graph: the graph whose weights were reset (self if inplace is True, otherwise the modified copy).

Source code in neps/search_spaces/architecture/graph.py
def reset_weights(self, inplace: bool = False):
    """
    Re-initialise the parameters of every Conv2d and Linear submodule.

    The reset is applied recursively via ``torch.nn.Module.apply`` by
    calling each matching layer's ``reset_parameters()``.

    Args:
        inplace (bool): If True, modify this graph directly; otherwise
            operate on a clone obtained from ``self.clone()``.
    Returns:
        Graph: The graph whose weights were reset (``self`` if ``inplace``
            is True, otherwise the clone).
    """
    resettable = (torch.nn.Conv2d, torch.nn.Linear)

    def _reset(module):
        if isinstance(module, resettable):
            module.reset_parameters()

    target = self if inplace else self.clone()
    target.apply(_reset)
    return target

set_at_edges #

set_at_edges(key, value, shared=False)

Sets the attribute for all edges in this and any child graph

Source code in neps/search_spaces/architecture/graph.py
def set_at_edges(self, key, value, shared=False):
    """
    Set the attribute ``key`` to ``value`` on all edges of this graph and of
    its child graphs, skipping edges whose data reports ``is_final()``.

    Args:
        key: Name of the edge attribute to set.
        value: Value to assign.
        shared (bool): Passed as ``single_instances`` when collecting the
            child graphs and forwarded to the edge data's ``set`` call.
    """
    targets = self._get_child_graphs(single_instances=shared) + [self]
    for graph in targets:
        logger.debug(f"Updating edges of graph {graph.name}")
        for *_, edge_data in graph.edges.data():
            if edge_data.is_final():
                continue
            edge_data.set(key, value, shared)

set_input #

set_input(node_idxs: list)

Route the input from specific parent edges to the input nodes of this subgraph. Inputs are assigned in lexicographical order.

Example: - Parent node (i.e. node where self is located on) has two incoming edges from nodes 3 and 5. - self has two input nodes 1 and 2 (i.e. nodes without an incoming edge) - node_idxs = [5, 3] Then input of node 5 is routed to node 1 and input of node 3 is routed to node 2.

Similarly, if node_idxs = [5, 5] then input of node 5 is routed to both node 1 and 2. Warning: In this case the output of another incoming edge is ignored!

Should be used in a builder-like pattern: 'subgraph'=Graph().set_input([5, 3])

PARAMETER DESCRIPTION
node_idxs

The indices of the nodes where the data is coming from.

TYPE: list

RETURNS DESCRIPTION
Graph

self with input node indices set.

Source code in neps/search_spaces/architecture/graph.py
def set_input(self, node_idxs: list):
    """
    Route the input from specific parent edges to the input nodes of
    this subgraph. Inputs are assigned in lexicographical order.

    Example:
    - Parent node (i.e. node where `self` is located on) has two
      incoming edges from nodes 3 and 5.
    - `self` has two input nodes 1 and 2 (i.e. nodes without
      an incoming edge)
    - `node_idxs = [5, 3]`
    Then input of node 5 is routed to node 1 and input of node 3
    is routed to node 2.

    Similarly, if `node_idxs = [5, 5]` then input of node 5 is routed
    to both node 1 and 2. Warning: In this case the output of another
    incoming edge is ignored!

    Should be used in a builder-like pattern: `'subgraph'=Graph().set_input([5, 3])`

    Args:
        node_idxs (list): The indices of the nodes where the data is coming
            from, one entry per input node of this graph.

    Returns:
        Graph: self with input node indices set.

    Raises:
        AssertionError: If ``len(node_idxs)`` differs from the number of
            input nodes. The check is raised explicitly (not via ``assert``)
            so that it is not stripped when Python runs with ``-O``.
    """
    num_innodes = sum(self.in_degree(n) == 0 for n in self.nodes)
    if num_innodes != len(node_idxs):
        raise AssertionError(
            "Expecting node index for every input node. Expected {}, got {}".format(
                num_innodes, len(node_idxs)
            )
        )
    self.input_node_idxs = node_idxs  # type: ignore[assignment]
    return self

set_scope #

set_scope(scope: str, recursively=True)

Sets the scope of this instance of the graph.

The function should be used in a builder-like pattern 'subgraph'=Graph().set_scope("scope").

PARAMETER DESCRIPTION
scope

the scope

TYPE: str

recursively

Also set the scope for all child graphs. default True

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
Graph

self with the scope set.

Source code in neps/search_spaces/architecture/graph.py
def set_scope(self, scope: str, recursively=True):
    """
    Assign ``scope`` to this graph and, optionally, to all child graphs.

    Meant for builder-style use, e.g. `'subgraph'=Graph().set_scope("scope")`.

    Args:
        scope (str): The scope to assign.
        recursively (bool): If True (default), every child graph (all
            instances, not only unique ones) receives the same scope.

    Returns:
        Graph: self, with its scope updated.
    """
    self.scope = scope
    children = self._get_child_graphs(single_instances=False) if recursively else []
    for child in children:
        child.scope = scope
    return self

to_graph_repr #

to_graph_repr(graph: Graph, edge_attr: bool) -> DiGraph

Transforms NASLib-esque graph to NetworkX graph.

PARAMETER DESCRIPTION
graph

NASLib-esque graph.

TYPE: Graph

edge_attr

Transform to edge attribution or node attribution.

TYPE: bool

RETURNS DESCRIPTION
DiGraph

nx.DiGraph: edge- or node-attributed representation of computational graph.

Source code in neps/search_spaces/architecture/core_graph_grammar.py
def to_graph_repr(self, graph: Graph, edge_attr: bool) -> nx.DiGraph:
    """Transforms NASLib-esque graph to NetworkX graph.

    Args:
        graph (Graph): NASLib-esque graph.
        edge_attr (bool): Transform to edge attribution or node attribution.
            * True: same topology as ``graph``. Every edge carries its op
              label under ``self.edge_label``; ops that are nested ``Graph``
              instances are labelled with their ``name``.
            * False: every edge of ``graph`` becomes a node labelled with
              that edge's ``self.edge_label`` attribute. Node ``0``
              ("input") and node ``graph.size() + 1`` ("output") are added,
              and two new nodes are connected when the corresponding edges
              are consecutive in ``graph``.

    Returns:
        nx.DiGraph: edge- or node-attributed representation of computational
            graph, with ``graph_type`` set to ``"edge_attr"`` or ``"node_attr"``.
    """
    g = nx.DiGraph()
    if edge_attr:
        g.add_nodes_from(graph.nodes())
        for u, v in graph.edges():
            edge_data = graph.edges[u, v]
            if isinstance(edge_data["op"], Graph):
                # Nested graphs are labelled by their name. Store that label
                # under the configured key (it used to be hard-coded to
                # "op_name", diverging from all other edges whenever
                # self.edge_label was customised).
                label = edge_data["op"].name
            else:
                label = edge_data[self.edge_label]
            g.add_edge(u, v, **{self.edge_label: label})
        g.graph_type = "edge_attr"
    else:
        # NOTE(review): assumes a single source and a single sink; only the
        # first node with in-degree 0 / out-degree 0 is used.
        src = [n for n in graph.nodes() if graph.in_degree(n) == 0][0]
        tgt = [n for n in graph.nodes() if graph.out_degree(n) == 0][0]
        nof_edges = graph.size()
        output_node = nof_edges + 1
        g.add_nodes_from(
            [
                (0, {self.edge_label: "input"}),
                (output_node, {self.edge_label: "output"}),
            ]
        )
        node_counter = 1
        # maps a node of `graph` to the new nodes (= edges of `graph`)
        # that end in it
        open_edge: dict = {}
        for node in nx.topological_sort(graph):
            for edge in graph.out_edges(node):
                g.add_node(
                    node_counter,
                    **{self.edge_label: graph.edges[edge][self.edge_label]},
                )

                u, v = edge
                if u == src:  # special case for input node
                    g.add_edge(0, node_counter)
                if v == tgt:  # special case of output node
                    g.add_edge(node_counter, output_node)
                # connect every already seen edge ending in u to this edge
                for node_count in open_edge.get(u, []):
                    g.add_edge(node_count, node_counter)

                open_edge.setdefault(v, []).append(node_counter)
                node_counter += 1
        g.graph_type = "node_attr"

    self._check_graph(g)

    return g

unparse #

unparse()

Undo the pytorch parsing by reconstructing the graph using the networkx data structures.

This is done recursively also for child graphs.

RETURNS DESCRIPTION
Graph

An unparsed shallow copy of the graph.

Source code in neps/search_spaces/architecture/graph.py
def unparse(self):
    """
    Undo the pytorch parsing by reconstructing the graph using the
    networkx data structures.

    This is done recursively also for child graphs: subgraphs stored at
    nodes and ``Graph`` ops stored at edges are replaced by their unparsed
    versions. Note that this replacement is written into ``self``'s own
    node attribute dicts and edge data, i.e. ``self`` is modified too.

    A new instance is created with ``self.__class__()``, so the class must
    be constructible without arguments. Node/edge attributes and the members
    listed below are copied over. Attribute values themselves are not
    deep-copied (presumably shared with ``self``; verify for edge data).

    Returns:
        Graph: An unparsed shallow copy of the graph.
    """
    g = self.__class__()
    g.clear()

    graph_nodes = self.nodes
    graph_edges = self.edges

    # unparse possible child graphs
    # be careful with copying/deepcopying here cause of shared edge data
    for _, data in graph_nodes.data():
        if "subgraph" in data:
            data["subgraph"] = data["subgraph"].unparse()
    for _, _, data in graph_edges.data():
        if isinstance(data.op, Graph):
            data.set("op", data.op.unparse())

    # create the new graph
    # Remember to add all members here to update. I know it is ugly but don't know better
    g.add_nodes_from(graph_nodes.data())
    g.add_edges_from(graph_edges.data())
    g.graph.update(self.graph)
    g.name = self.name
    g.input_node_idxs = self.input_node_idxs
    g.scope = self.scope
    g.is_parsed = False  # the copy has no registered pytorch submodules yet
    g._id = self._id
    g.OPTIMIZER_SCOPE = self.OPTIMIZER_SCOPE
    g.QUERYABLE = self.QUERYABLE

    return g

update_edges #

update_edges(
    update_func: Callable,
    scope="all",
    private_edge_data: bool = False,
)

This updates the edge data of this graph and all child graphs. This is the preferred way to manipulate the edges after the definition of the graph, e.g. by optimizers that want to insert their own op. update_func is called as update_func(edge=...) for every non-final edge, so optimizers can initialize and store necessary information at the edges.

Note that edges marked as 'final' will not be updated here.

PARAMETER DESCRIPTION
update_func

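Function that is called as update_func(edge=edge), where edge is an AttrDict with head, tail and data (the EdgeData of the edge). Its return value is ignored, so changes must be made to edge.data in place.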

TYPE: callable

scope

Can be "all" or list of scopes to be updated.

TYPE: str or list(str) DEFAULT: 'all'

private_edge_data

If set to true, this means update_func will be applied to all edges. THIS IS NOT RECOMMENDED FOR SHARED ATTRIBUTES. Shared attributes should be set only once; we take care that they are synchronized across all copies of this graph.

The only use case for setting it to true is when actually changing op during the initialization of the optimizer (e.g. replacing it with MixedOp or SampleOp).

TYPE: bool DEFAULT: False

Source code in neps/search_spaces/architecture/graph.py
def update_edges(
    self, update_func: Callable, scope="all", private_edge_data: bool = False
):
    """
    Apply ``update_func`` to the edges of this graph and all child graphs.

    This is the preferred way to manipulate the edges after the graph has
    been defined, e.g. by optimizers that want to insert their own op and
    initialize or store information at the edges. Edges whose data reports
    ``is_final()`` are skipped. Finally ``self._delete_flagged_edges()`` is
    called.

    Args:
        update_func (callable): Called once per edge as
            ``update_func(edge=edge)``, where ``edge`` is an ``AttrDict`` with
            ``head`` (source node), ``tail`` (target node) and ``data`` (the
            edge data). Its return value is ignored, so changes must be made
            to ``edge.data`` in place. It is validated first via
            ``Graph._verify_update_function``.
        scope (str or list(str)): "all", a single scope, or a list of scopes;
            only graphs whose scope matches are updated.
        private_edge_data (bool): If set to True, ``update_func`` is applied
            to the edges of all copies of a graph (child graphs are collected
            with ``single_instances=False``). THIS IS NOT RECOMMENDED FOR
            SHARED ATTRIBUTES. Shared attributes should be set only once;
            they are kept synchronized across all copies of this graph.

            The only use case for setting it to True is when actually
            changing ``op`` during the initialization of the optimizer (e.g.
            replacing it with MixedOp or SampleOp).
    """
    Graph._verify_update_function(update_func, private_edge_data)
    assert scope is not None
    graphs = self._get_child_graphs(single_instances=not private_edge_data) + [self]
    for graph in graphs:
        in_scope = (
            scope == "all"
            or scope == graph.scope
            or (isinstance(scope, list) and graph.scope in scope)
        )
        if not in_scope:
            continue
        logger.debug(f"Updating edges of graph {graph.name}")
        for head, tail, edge_data in graph.edges.data():
            if edge_data.is_final():
                continue
            update_func(edge=AttrDict(head=head, tail=tail, data=edge_data))
    self._delete_flagged_edges()

update_nodes #

update_nodes(
    update_func: Callable,
    scope="all",
    single_instances: bool = True,
)

Update the nodes of the graph and their incoming and outgoing edges by iterating over the graph and applying update_func to each of them. This is the preferred way to change the search space once it has been defined.

Note that edges marked as 'final' will not be updated here.

PARAMETER DESCRIPTION
update_func

Function that accepts three keyword parameters named node, in_edges, out_edges. - node is a tuple (int, dict) containing the index and the attributes of the current node. - in_edges is a list of tuples with the index of the predecessor node and the EdgeData of the incoming edge. - out_edges is a list of tuples with the index of the successor node and the EdgeData of the outgoing edge.

TYPE: callable

scope

Can be "all" or list of scopes to be updated. Only graphs and child graphs with the specified scope are considered

TYPE: str or list(str) DEFAULT: 'all'

single_instances

If set to false, this means update_func will be applied to nodes of all copies of a graph. THIS IS NOT RECOMMENDED FOR SHARED ATTRIBUTES, i.e. when manipulating the shared data of incoming or outgoing edges. Shared attributes should be set only once; we take care that they are synchronized across all copies of this graph.

The only use case for setting it to false is when actually changing op during the initialization of the optimizer (e.g. replacing it with MixedOp or SampleOp).

TYPE: bool

Source code in neps/search_spaces/architecture/graph.py
def update_nodes(
    self, update_func: Callable, scope="all", single_instances: bool = True
):
    """
    Apply ``update_func`` to every node of this graph and its child graphs.

    Nodes of each selected graph are visited in lexicographical topological
    order. This is the preferred way to change the search space once it has
    been defined. Edges whose data reports ``is_final()`` are not passed to
    ``update_func``. Finally ``self._delete_flagged_edges()`` is called.

    Args:
        update_func (callable): Called with the keyword arguments
            ``node, in_edges, out_edges``:
                - ``node`` is a tuple ``(node_idx, node_attributes)``.
                - ``in_edges`` is a list of ``(predecessor_idx, edge_data)``
                  tuples for the incoming edges.
                - ``out_edges`` is a list of ``(successor_idx, edge_data)``
                  tuples for the outgoing edges.
        scope (str or list(str)): "all", a single scope, or a list of scopes;
            only graphs and child graphs whose scope matches are updated.
        single_instances (bool): Forwarded to ``_get_child_graphs``. If set
            to False, ``update_func`` is applied to the nodes of all copies of
            a graph. THIS IS NOT RECOMMENDED FOR SHARED ATTRIBUTES, i.e. when
            manipulating the shared data of incoming or outgoing edges.
            Shared attributes should be set only once; they are kept
            synchronized across all copies of this graph.

            The only use case for setting it to False is when actually
            changing ``op`` during the initialization of the optimizer (e.g.
            replacing it with MixedOp or SampleOp).
    """
    assert scope is not None
    for graph in self._get_child_graphs(single_instances) + [self]:
        selected = (
            scope == "all"
            or graph.scope == scope
            or (isinstance(scope, list) and graph.scope in scope)
        )
        if not selected:
            continue
        logger.debug(f"Updating nodes of graph {graph.name}")
        for node_idx in lexicographical_topological_sort(graph):
            node = (node_idx, graph.nodes[node_idx])
            # in_edges yields (predecessor, node_idx, data)
            in_edges = [
                (pred, data)
                for pred, _, data in graph.in_edges(node_idx, data=True)
                if not data.is_final()
            ]
            # out_edges yields (node_idx, successor, data)
            out_edges = [
                (succ, data)
                for _, succ, data in graph.out_edges(node_idx, data=True)
                if not data.is_final()
            ]
            update_func(node=node, in_edges=in_edges, out_edges=out_edges)
    self._delete_flagged_edges()