def ArchitectureParameter(**kwargs):
"""Factory function"""
if "structure" not in kwargs:
raise ValueError("Factory function requires structure")
if not isinstance(kwargs["structure"], list) or len(kwargs["structure"]) == 1:
base = GraphGrammar
else:
base = GraphGrammarMultipleRepetitive
class _FunctionParameter(base):
def __init__(
self,
structure: Grammar
| list[Grammar]
| ConstrainedGrammar
| list[ConstrainedGrammar]
| str
| list[str]
| dict
| list[dict],
primitives: dict,
constraint_kwargs: dict | None = None,
name: str = "ArchitectureParameter",
set_recursive_attribute: Callable | None = None,
**kwargs,
):
local_vars = locals()
self.input_kwargs = {
args: local_vars[args]
for args in inspect.getfullargspec(self.__init__).args # type: ignore[misc]
if args != "self"
}
self.input_kwargs.update(**kwargs)
if isinstance(structure, list):
structures = [
_dict_structure_to_str(
st,
primitives,
repetitive_mapping=kwargs["terminal_to_sublanguage_map"]
if "terminal_to_sublanguage_map" in kwargs
else None,
)
if isinstance(st, dict)
else st
for st in structure
]
_structures = []
for st in structures:
if isinstance(st, str):
if constraint_kwargs is None:
_st = Grammar.fromstring(st)
else:
_st = ConstrainedGrammar.fromstring(st)
_st.set_constraints(**constraint_kwargs)
_structures.append(_st) # type: ignore[has-type]
structures = _structures
super().__init__(
grammars=structures,
terminal_to_op_names=primitives,
edge_attr=False,
**kwargs,
)
else:
if isinstance(structure, dict):
structure = _dict_structure_to_str(structure, primitives)
if isinstance(structure, str):
if constraint_kwargs is None:
structure = Grammar.fromstring(structure)
else:
structure = ConstrainedGrammar.fromstring(structure)
structure.set_constraints(**constraint_kwargs) # type: ignore[union-attr]
super().__init__(
grammar=structure, # type: ignore[arg-type]
terminal_to_op_names=primitives,
edge_attr=False,
**kwargs,
)
self._set_recursive_attribute = set_recursive_attribute
self.name: str = name
def to_pytorch(self) -> nn.Module:
self.clear_graph()
if len(self.nodes()) == 0:
composed_function = self.compose_functions()
# part below is required since PyTorch has no standard functional API
self.graph_to_self(composed_function)
self.prune_graph()
if self._set_recursive_attribute:
m = _build(
self, self._set_recursive_attribute
)
if m is not None:
return m
self.compile()
self.update_op_names()
return super().to_pytorch() # create PyTorch model
def to_tensorflow(self, inputs):
composed_function = self.compose_functions(flatten_graph=False)
return composed_function(inputs)
def create_new_instance_from_id(self, identifier: str):
g = ArchitectureParameter(**self.input_kwargs) # type: ignore[arg-type]
g.load_from(identifier)
return g
return _FunctionParameter(**kwargs)