Skip to content

node_collection

ArtifactNodeCollection dataclass

Bases: UserCodeNodeCollection

ArtifactNodeCollection is a special subclass of UserCodeNodeCollection which returns Artifacts.

If is_pre_computed is True, this means that this NodeCollection should use a precomputed Artifact's value.

Source code in lineapy/graph_reader/node_collection.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
@dataclass
class ArtifactNodeCollection(UserCodeNodeCollection):
    """
    ArtifactNodeCollection is a special subclass of UserCodeNodeCollection
    which returns Artifacts.

    If is_pre_computed is True, this NodeCollection should use a precomputed
    Artifact's value; pre_computed_artifact must then be set, since its
    "artifact_name" and "version" entries are used to generate the lookup.
    """

    is_pre_computed: bool = False
    pre_computed_artifact: Optional[LineaArtifactDef] = None

    def get_function_definition(
        self, graph: Graph, include_non_slice_as_comment=False, indentation=4
    ) -> str:
        """
        Return the source of a ``get_<safename>`` function for this collection.

        If self.is_pre_computed is False, the parent implementation is used
        unchanged. Otherwise the calculation block is replaced with
        ``lineapy.get(<artifact_name>, <version>).get_value()`` and the
        generated function takes no arguments.
        """
        if not self.is_pre_computed:
            return super().get_function_definition(
                graph, include_non_slice_as_comment, indentation
            )

        # Function-scope import: json is only needed on the pre-computed path.
        import json

        assert self.pre_computed_artifact is not None
        indentation_block = " " * indentation
        return_string = ", ".join(self.return_variables)
        # json.dumps renders a double-quoted literal with quotes, backslashes
        # and control characters escaped, so the generated code stays valid
        # Python for any artifact name while ordinary names render exactly as
        # before. ensure_ascii=False keeps non-ASCII characters verbatim.
        artifact_name_literal = json.dumps(
            self.pre_computed_artifact["artifact_name"], ensure_ascii=False
        )
        version = self.pre_computed_artifact["version"]

        generated_lines = [
            f"def get_{self.safename}():",
            f"{indentation_block}import lineapy",
            f"{indentation_block}{return_string}=lineapy.get({artifact_name_literal}, {version}).get_value()",
            f"{indentation_block}return {return_string}",
        ]
        return "\n".join(generated_lines)

get_function_definition(graph, include_non_slice_as_comment=False, indentation=4)

Return a function body code block from code in the graph segment.

If self.is_pre_computed is True, will replace the calculation block with lineapy.get().get_value()

Source code in lineapy/graph_reader/node_collection.py
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
def get_function_definition(
    self, graph: Graph, include_non_slice_as_comment=False, indentation=4
) -> str:
    """
    Return the source of a ``get_<safename>`` function for this collection.

    If self.is_pre_computed is False, the parent implementation is used
    unchanged. Otherwise the calculation block is replaced with
    ``lineapy.get(<artifact_name>, <version>).get_value()`` and the
    generated function takes no arguments.
    """
    if not self.is_pre_computed:
        return super().get_function_definition(
            graph, include_non_slice_as_comment, indentation
        )

    # Function-scope import: json is only needed on the pre-computed path.
    import json

    assert self.pre_computed_artifact is not None
    indentation_block = " " * indentation
    return_string = ", ".join(self.return_variables)
    # json.dumps renders a double-quoted literal with quotes, backslashes
    # and control characters escaped, so the generated code stays valid
    # Python for any artifact name while ordinary names render exactly as
    # before. ensure_ascii=False keeps non-ASCII characters verbatim.
    artifact_name_literal = json.dumps(
        self.pre_computed_artifact["artifact_name"], ensure_ascii=False
    )
    version = self.pre_computed_artifact["version"]

    generated_lines = [
        f"def get_{self.safename}():",
        f"{indentation_block}import lineapy",
        f"{indentation_block}{return_string}=lineapy.get({artifact_name_literal}, {version}).get_value()",
        f"{indentation_block}return {return_string}",
    ]
    return "\n".join(generated_lines)

BaseNodeCollection dataclass

BaseNodeCollection represents a collection of Nodes in a Graph.

Used for defining modules and functions.

Source code in lineapy/graph_reader/node_collection.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
@dataclass
class BaseNodeCollection:
    """
    A named group of Graph nodes, stored as a set of LineaIDs.

    It is the common base used for defining modules and functions.
    """

    node_list: Set[LineaID]
    name: str

    def __post_init__(self):
        # safename is the slugified form of name; it is not a dataclass field.
        self.safename = slugify(self.name)

    def _get_raw_codeblock(
        self, graph: Graph, include_non_slice_as_comment=False
    ) -> str:
        """
        Return, as a string, the source code that get_source_code_from_graph
        produces for this collection's nodes in ``graph``.
        """
        source_code = get_source_code_from_graph(
            self.node_list, graph, include_non_slice_as_comment
        )
        return str(source_code)

ImportNodeCollection

Bases: BaseNodeCollection

ImportNodeCollection contains all the nodes used to import libraries in a Session.

Source code in lineapy/graph_reader/node_collection.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
class ImportNodeCollection(BaseNodeCollection):
    """
    Collection holding every node that imports a library in a Session.
    """

    def get_import_block(self, graph: Graph, indentation=0) -> str:
        """
        Build the import section of the graph segment.

        Lines of the raw code that are empty or contain only spaces are
        dropped, the remaining lines are prefixed with ``indentation`` spaces
        and joined with newlines, and a non-empty result is followed by two
        newlines. Returns "" when nothing is left.
        """
        source = self._get_raw_codeblock(graph)
        if source == "":
            return ""

        prefix = " " * indentation
        kept_lines = [
            prefix + line
            for line in source.split("\n")
            if line.strip(" ") != ""
        ]
        if not kept_lines:
            return ""
        return "\n".join(kept_lines) + "\n\n"

get_import_block(graph, indentation=0)

Return a code block for import statement of the graph segment

Source code in lineapy/graph_reader/node_collection.py
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
def get_import_block(self, graph: Graph, indentation=0) -> str:
    """
    Build the import section of the graph segment.

    Lines of the raw code that are empty or contain only spaces are
    dropped, the remaining lines are prefixed with ``indentation`` spaces
    and joined with newlines, and a non-empty result is followed by two
    newlines. Returns "" when nothing is left.
    """
    source = self._get_raw_codeblock(graph)
    if source == "":
        return ""

    prefix = " " * indentation
    kept_lines = [
        prefix + line
        for line in source.split("\n")
        if line.strip(" ") != ""
    ]
    if not kept_lines:
        return ""
    return "\n".join(kept_lines) + "\n\n"

InputVarNodeCollection

Bases: BaseNodeCollection

InputVarNodeCollection contains all the nodes that are needed as input parameters to the Session.

Source code in lineapy/graph_reader/node_collection.py
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
class InputVarNodeCollection(BaseNodeCollection):
    """
    InputVarNodeCollection contains all the nodes that are needed as input parameters
    to the Session.
    """

    def get_input_parameters_block(self, graph: Graph, indentation=4) -> str:
        """
        Return a code block for input parameters of the graph segment.

        A single parameter line is returned as is. Several lines are rendered
        one per line, each indented by ``indentation`` spaces and followed by
        a comma, with a leading newline. Returns "" when there is no raw code.
        """
        raw_codeblock = self._get_raw_codeblock(graph)
        if raw_codeblock == "":
            return ""

        # str.split always yields at least one element, so only the
        # "one line" and "several lines" cases can occur.
        input_parameters_lines = raw_codeblock.rstrip("\n").split("\n")
        if len(input_parameters_lines) == 1:
            return input_parameters_lines[0]

        indentation_block = " " * indentation
        return "\n" + "".join(
            f"{indentation_block}{line},\n" for line in input_parameters_lines
        )

get_input_parameters_block(graph, indentation=4)

Return a code block for input parameters of the graph segment

Source code in lineapy/graph_reader/node_collection.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
def get_input_parameters_block(self, graph: Graph, indentation=4) -> str:
    """
    Return a code block for input parameters of the graph segment.

    A single parameter line is returned as is. Several lines are rendered
    one per line, each indented by ``indentation`` spaces and followed by
    a comma, with a leading newline. Returns "" when there is no raw code.
    """
    raw_codeblock = self._get_raw_codeblock(graph)
    if raw_codeblock == "":
        return ""

    # str.split always yields at least one element, so only the
    # "one line" and "several lines" cases can occur.
    input_parameters_lines = raw_codeblock.rstrip("\n").split("\n")
    if len(input_parameters_lines) == 1:
        return input_parameters_lines[0]

    indentation_block = " " * indentation
    return "\n" + "".join(
        f"{indentation_block}{line},\n" for line in input_parameters_lines
    )

NodeInfo dataclass

Parameters:

Name Type Description Default
assigned_variables Set[str]

variables assigned at this node

field(default_factory=set)
assigned_artifact Optional[str]

this node is pointing to some artifact

field(default=None)
dependent_variables Set[str]

union over the predecessor nodes: if a predecessor node assigns any variable, use its assigned variables; otherwise, use its dependent_variables

field(default_factory=set)
tracked_variables Set[str]

variables that this node is pointing to

field(default_factory=set)
predecessors Set[LineaID]

predecessors of the node

field(default_factory=set)
module_import Set[str]

module name/alias that this node is pointing to

field(default_factory=set)
artifact_name Optional[str]

the artifact calculating block to which this node belongs

field(default=None)
Source code in lineapy/graph_reader/node_collection.py
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
@dataclass
class NodeInfo:
    """
    Variable and artifact bookkeeping attached to a single node.

    Parameters
    ----------
    assigned_variables: Set[str]
        variables assigned at this node
    assigned_artifact: Optional[str]
        artifact this node is pointing to, if any
    dependent_variables: Set[str]
        union over the predecessor nodes: the assigned variables of a
        predecessor that assigns any, otherwise its dependent_variables
    predecessors: Set[LineaID]
        predecessors of the node
    tracked_variables: Set[str]
        variables that this node is pointing to
    module_import: Set[str]
        module name/alias that this node is pointing to
    artifact_name: Optional[str]
        artifact calculating block to which this node belongs
    """

    # Field order is part of the positional constructor signature.
    assigned_variables: Set[str] = field(default_factory=set)
    assigned_artifact: Optional[str] = None
    dependent_variables: Set[str] = field(default_factory=set)
    predecessors: Set[LineaID] = field(default_factory=set)
    tracked_variables: Set[str] = field(default_factory=set)
    module_import: Set[str] = field(default_factory=set)
    artifact_name: Optional[str] = None

UserCodeNodeCollection dataclass

Bases: BaseNodeCollection

This class is used for holding a set of nodes (as a subgraph) corresponding to user code that can be sliced on.

It is initiated with list of nodes::

seg = NodeCollection(node_list)

For variable calculation purposes, it can identify all variables related to these by running::

seg._update_variable_info()

For all code generating purposes, it needs to initiate a real graph object by::

seg.update_raw_codeblock()
Source code in lineapy/graph_reader/node_collection.py
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
@dataclass
class UserCodeNodeCollection(BaseNodeCollection):
    """
    A set of nodes (a subgraph) corresponding to user code that can be
    sliced on.

    On top of the node set and name inherited from BaseNodeCollection, it
    keeps the variable bookkeeping for the segment. Calling::

        seg.update_variable_info(node_context, input_parameters_node)

    fills in the dependent, assigned, all and input variable sets, and::

        seg.get_function_definition(graph)

    renders the segment as the source of a ``get_<safename>`` function.
    """

    assigned_variables: Set[str] = field(default_factory=set)
    dependent_variables: Set[str] = field(default_factory=set)
    all_variables: Set[str] = field(default_factory=set)
    input_variables: Set[str] = field(default_factory=set)
    tracked_variables: Set[str] = field(default_factory=set)
    predecessor_nodes: Set[LineaID] = field(default_factory=set)
    # A list, not a set, so that the return order is preserved
    return_variables: List[str] = field(default_factory=list)

    def get_input_variable_sources(self, node_context) -> Dict[str, Set[str]]:
        """
        Map each predecessor artifact name to the variables it supplies.

        For every predecessor node, its assigned variables are used when it
        has any, otherwise its tracked variables. Predecessors whose
        artifact_name is "module_import" are left out.
        """
        sources: Dict[str, Set[str]] = {}
        for pred_id in self.predecessor_nodes:
            pred_info = node_context[pred_id]
            supplied = (
                pred_info.assigned_variables or pred_info.tracked_variables
            )
            pred_art = pred_info.artifact_name
            assert isinstance(pred_art, str)
            if pred_art == "module_import":
                continue
            # Accumulate into a set owned by the result so the sets held in
            # node_context are never mutated.
            sources.setdefault(pred_art, set()).update(supplied)
        return sources

    def update_variable_info(self, node_context, input_parameters_node):
        """
        Recompute the variable sets of this collection from ``node_context``
        and add the user defined input parameters.

        ``input_parameters_node`` maps a variable name to a node id; a
        variable becomes an input variable when its node id belongs to this
        collection.
        """
        node_infos = [node_context[nid] for nid in self.node_list]

        self.dependent_variables = self.dependent_variables.union(
            *(info.dependent_variables for info in node_infos)
        )
        # Variables that get assigned inside these nodes
        self.assigned_variables = self.assigned_variables.union(
            *(info.assigned_variables for info in node_infos)
        )
        # Every variable seen in these nodes, including the returned ones
        self.all_variables = (
            self.dependent_variables
            | self.assigned_variables
            | set(self.return_variables)
        )
        # Variables that have to come from outside the collection
        self.input_variables = self.all_variables - self.assigned_variables
        # User defined parameters are always inputs
        self.input_variables |= {
            var
            for var, nid in input_parameters_node.items()
            if nid in self.node_list
        }

    def get_function_definition(
        self, graph: Graph, include_non_slice_as_comment=False, indentation=4
    ) -> str:
        """
        Render the graph segment as the source of a standalone function.

        The function is named ``get_<safename>``, takes the sorted input
        variables as arguments, contains the non-blank lines of the raw code
        indented by ``indentation`` spaces, and returns the return variables
        in their stored order.
        """
        pad = " " * indentation
        returned = ", ".join(self.return_variables)
        raw_code = self._get_raw_codeblock(graph, include_non_slice_as_comment)
        body = "\n".join(
            pad + line for line in raw_code.split("\n") if line.strip(" ")
        )
        arguments = ", ".join(sorted(self.input_variables))

        header = f"def get_{self.safename}({arguments}):"
        footer = f"{pad}return {returned}"
        return "\n".join([header, body, footer])

get_function_definition(graph, include_non_slice_as_comment=False, indentation=4)

Return a standalone function to define the function of the graph segment.

Source code in lineapy/graph_reader/node_collection.py
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
def get_function_definition(
    self, graph: Graph, include_non_slice_as_comment=False, indentation=4
) -> str:
    """
    Render the graph segment as the source of a standalone function.

    The function is named ``get_<safename>``, takes the sorted input
    variables as arguments, contains the non-blank lines of the raw code
    indented by ``indentation`` spaces, and returns the return variables
    in their stored order.
    """
    pad = " " * indentation
    returned = ", ".join(self.return_variables)
    raw_code = self._get_raw_codeblock(graph, include_non_slice_as_comment)
    body = "\n".join(
        pad + line for line in raw_code.split("\n") if line.strip(" ")
    )
    arguments = ", ".join(sorted(self.input_variables))

    header = f"def get_{self.safename}({arguments}):"
    footer = f"{pad}return {returned}"
    return "\n".join([header, body, footer])

get_input_variable_sources(node_context)

Get information about which input variable is originated from which artifact.

Source code in lineapy/graph_reader/node_collection.py
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
def get_input_variable_sources(self, node_context) -> Dict[str, Set[str]]:
    """
    Map each predecessor artifact name to the variables it supplies.

    For every predecessor node, its assigned variables are used when it
    has any, otherwise its tracked variables. Predecessors whose
    artifact_name is "module_import" are left out.
    """
    sources: Dict[str, Set[str]] = {}
    for pred_id in self.predecessor_nodes:
        pred_info = node_context[pred_id]
        supplied = pred_info.assigned_variables or pred_info.tracked_variables
        pred_art = pred_info.artifact_name
        assert isinstance(pred_art, str)
        if pred_art == "module_import":
            continue
        # Accumulate into a set owned by the result so the sets held in
        # node_context are never mutated.
        sources.setdefault(pred_art, set()).update(supplied)
    return sources

update_variable_info(node_context, input_parameters_node)

Update variable information to add user defined input parameters.

Source code in lineapy/graph_reader/node_collection.py
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
def update_variable_info(self, node_context, input_parameters_node):
    """
    Recompute the variable sets of this collection from ``node_context``
    and add the user defined input parameters.

    ``input_parameters_node`` maps a variable name to a node id; a
    variable becomes an input variable when its node id belongs to this
    collection.
    """
    node_infos = [node_context[nid] for nid in self.node_list]

    self.dependent_variables = self.dependent_variables.union(
        *(info.dependent_variables for info in node_infos)
    )
    # Variables that get assigned inside these nodes
    self.assigned_variables = self.assigned_variables.union(
        *(info.assigned_variables for info in node_infos)
    )
    # Every variable seen in these nodes, including the returned ones
    self.all_variables = (
        self.dependent_variables
        | self.assigned_variables
        | set(self.return_variables)
    )
    # Variables that have to come from outside the collection
    self.input_variables = self.all_variables - self.assigned_variables
    # User defined parameters are always inputs
    self.input_variables |= {
        var
        for var, nid in input_parameters_node.items()
        if nid in self.node_list
    }

Was this helpful?

Help us improve docs with your feedback!