Skip to content

API

neps.search_spaces.architecture.api

ArchitectureParameter

ArchitectureParameter(**kwargs)

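Factory function that builds an architecture parameter from the given keyword arguments; a `structure` argument is required.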

Source code in neps/search_spaces/architecture/api.py
def ArchitectureParameter(**kwargs):
    """Factory function that builds an architecture (graph-grammar) parameter.

    The base class is chosen from ``kwargs["structure"]``:

    * a single structure, or a list with exactly one entry, uses ``GraphGrammar``;
    * a list with several entries uses ``GraphGrammarMultipleRepetitive``.

    A class deriving from that base is created on the fly and instantiated with
    ``kwargs``.

    Args:
        **kwargs: Forwarded to ``_FunctionParameter.__init__``. Must contain
            ``structure``; the constructor also requires ``primitives``.

    Returns:
        An instance of the dynamically created ``_FunctionParameter`` class.

    Raises:
        ValueError: If ``structure`` is not supplied.
    """

    if "structure" not in kwargs:
        raise ValueError("Factory function requires structure")
    if not isinstance(kwargs["structure"], list) or len(kwargs["structure"]) == 1:
        # NOTE(review): a one-element list selects GraphGrammar here, yet __init__
        # still takes its list branch and calls super().__init__(grammars=...).
        # Confirm GraphGrammar accepts a `grammars` keyword for that case.
        base = GraphGrammar
    else:
        base = GraphGrammarMultipleRepetitive

    def _to_grammar(st, constraint_kwargs):
        """Parse ``st`` into a grammar if it is a string; otherwise return it unchanged.

        When ``constraint_kwargs`` is given, a ``ConstrainedGrammar`` is built and
        the constraints are applied; otherwise a plain ``Grammar`` is built.
        """
        if not isinstance(st, str):
            return st
        if constraint_kwargs is None:
            return Grammar.fromstring(st)
        grammar = ConstrainedGrammar.fromstring(st)
        grammar.set_constraints(**constraint_kwargs)  # type: ignore[union-attr]
        return grammar

    class _FunctionParameter(base):
        def __init__(
            self,
            structure: Grammar
            | list[Grammar]
            | ConstrainedGrammar
            | list[ConstrainedGrammar]
            | str
            | list[str]
            | dict
            | list[dict],
            primitives: dict,
            constraint_kwargs: dict | None = None,
            name: str = "ArchitectureParameter",
            set_recursive_attribute: Callable | None = None,
            **kwargs,
        ):
            """Build the grammar(s) from ``structure`` and initialise the base class.

            Args:
                structure: One structure or a list of structures. Each may be a
                    grammar object, a grammar string, or a dict that is first
                    converted to a string via ``_dict_structure_to_str``.
                primitives: Passed to the base class as ``terminal_to_op_names``.
                constraint_kwargs: If given, string structures are parsed as
                    ``ConstrainedGrammar`` and these constraints are applied.
                name: Stored on ``self.name``.
                set_recursive_attribute: Optional callback used by ``to_pytorch``.
                **kwargs: Forwarded to the base-class constructor.
            """
            # Snapshot the constructor arguments *before* any of them are
            # reassigned below, so create_new_instance_from_id can rebuild an
            # equivalent parameter from the original inputs.
            local_vars = locals()
            self.input_kwargs = {
                args: local_vars[args]
                for args in inspect.getfullargspec(self.__init__).args  # type: ignore[misc]
                if args != "self"
            }
            self.input_kwargs.update(**kwargs)

            if isinstance(structure, list):
                repetitive_mapping = kwargs.get("terminal_to_sublanguage_map")
                structures = [
                    _dict_structure_to_str(
                        st, primitives, repetitive_mapping=repetitive_mapping
                    )
                    if isinstance(st, dict)
                    else st
                    for st in structure
                ]
                # Non-string entries (already-built grammars) are kept as-is.
                # Previously they either raised UnboundLocalError or were
                # replaced by the grammar parsed in the previous iteration.
                structures = [_to_grammar(st, constraint_kwargs) for st in structures]

                super().__init__(
                    grammars=structures,
                    terminal_to_op_names=primitives,
                    edge_attr=False,
                    **kwargs,
                )
            else:
                if isinstance(structure, dict):
                    structure = _dict_structure_to_str(structure, primitives)

                structure = _to_grammar(structure, constraint_kwargs)

                super().__init__(
                    grammar=structure,  # type: ignore[arg-type]
                    terminal_to_op_names=primitives,
                    edge_attr=False,
                    **kwargs,
                )

            self._set_recursive_attribute = set_recursive_attribute
            self.name: str = name

        def to_pytorch(self) -> nn.Module:
            """Return a PyTorch module for this architecture.

            If a ``set_recursive_attribute`` callback was supplied and ``_build``
            returns a module, that module is returned directly. Otherwise the
            graph is compiled and the base class's ``to_pytorch`` is used.
            """
            self.clear_graph()
            if len(self.nodes()) == 0:
                composed_function = self.compose_functions()
                # part below is required since PyTorch has no standard functional API
                self.graph_to_self(composed_function)
                self.prune_graph()

                # The None-check lives inside this branch: previously `m` was
                # unbound (UnboundLocalError) when no callback was configured.
                if self._set_recursive_attribute:
                    m = _build(self, self._set_recursive_attribute)
                    if m is not None:
                        return m

                self.compile()
                self.update_op_names()
            return super().to_pytorch()  # create PyTorch model

        def to_tensorflow(self, inputs):
            """Compose the graph's functions (unflattened) and apply them to ``inputs``."""
            composed_function = self.compose_functions(flatten_graph=False)
            return composed_function(inputs)

        def create_new_instance_from_id(self, identifier: str):
            """Create a fresh parameter from the original constructor arguments and load ``identifier`` into it."""
            g = ArchitectureParameter(**self.input_kwargs)  # type: ignore[arg-type]
            g.load_from(identifier)
            return g

    return _FunctionParameter(**kwargs)