Reference Guide  2.5.0
assignment_trans.py
1 # BSD 3-Clause License
2 #
3 # Copyright (c) 2021-2024, Science and Technology Facilities Council.
4 # All rights reserved.
5 #
6 # Redistribution and use in source and binary forms, with or without
7 # modification, are permitted provided that the following conditions are met:
8 #
9 # * Redistributions of source code must retain the above copyright notice, this
10 # list of conditions and the following disclaimer.
11 #
12 # * Redistributions in binary form must reproduce the above copyright notice,
13 # this list of conditions and the following disclaimer in the documentation
14 # and/or other materials provided with the distribution.
15 #
16 # * Neither the name of the copyright holder nor the names of its
17 # contributors may be used to endorse or promote products derived from
18 # this software without specific prior written permission.
19 #
20 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23 # FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24 # COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25 # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26 # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27 # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29 # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30 # ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31 # POSSIBILITY OF SUCH DAMAGE.
32 # -----------------------------------------------------------------------------
33 # Authors: R. W. Ford, A. R. Porter, N. Nobre and S. Siso, STFC Daresbury Lab
34 # Modified by J. Henrichs, Bureau of Meteorology
35 
36 '''This module contains a transformation that replaces a PSyIR
37 assignment node with its adjoint form.
38 
39 '''
40 from __future__ import absolute_import
41 
42 from psyclone.core import SymbolicMaths
43 from psyclone.psyir.nodes import BinaryOperation, Assignment, Reference, \
44  Literal, UnaryOperation, IntrinsicCall
45 from psyclone.psyir.nodes.array_mixin import ArrayMixin
46 from psyclone.psyir.symbols import REAL_TYPE
47 from psyclone.psyir.transformations import TransformationError
48 
49 from psyclone.psyad.transformations import TangentLinearError
50 from psyclone.psyad.transformations.adjoint_trans import AdjointTransformation
51 
52 # pylint: disable=too-many-locals
53 # pylint: disable=too-many-branches
54 
55 
57  '''Implements a transformation to translate a Tangent-Linear
58  assignment to its Adjoint form.
59 
60  '''
61  def apply(self, node, options=None):
62  '''Apply the Assignment transformation to the specified node. The node
63  must be a valid tangent-linear assignment. The assignment is
64  replaced with its adjoint version.
65 
66  :param node: an Assignment node.
67  :type node: :py:class:`psyclone.psyir.nodes.Assignment`
68  :param options: a dictionary with options for transformations.
69  :type options: Optional[Dict[str, Any]]
70 
71  '''
72  self.validatevalidatevalidate(node, options)
73 
74  # Split the RHS of the assignment into [-]<term> +- <term> +- ...
75  rhs_terms = self._split_nodes_split_nodes(
76  node.rhs, [BinaryOperation.Operator.ADD,
77  BinaryOperation.Operator.SUB])
78 
79  deferred_inc = []
80  sym_maths = SymbolicMaths.get()
81  for rhs_term in rhs_terms:
82 
83  # Find the active var in rhs_term if one exists (we may
84  # find 0.0), storing it in 'active_var' and if so replace
85  # it with lhs_active_var storing the modified term in
86  # 'new_rhs_term'. Also determine whether this is an
87  # increment, storing the result in 'increment'.
88  increment = False
89  active_var = None
90  new_rhs_term = rhs_term.copy()
91  for ref in new_rhs_term.walk(Reference):
92  if ref.symbol not in self._active_variables_active_variables:
93  continue
94  active_var = ref
95  # Identify whether this reference on the RHS matches the
96  # one on the LHS - if so we have an increment.
97  if node.is_array_assignment and isinstance(ref, ArrayMixin):
98  # TODO #1537 - we can't just do `sym_maths.equal` if we
99  # have an array range because the SymbolicMaths class does
100  # not currently support them.
101  # Since we have already checked (in validate) that any
102  # references to the same symbol on the RHS have the same
103  # range, this is an increment if the symbols match.
104  if node.lhs.symbol is ref.symbol:
105  increment = True
106  else:
107  if sym_maths.equal(ref, node.lhs):
108  increment = True
109  if ref.parent:
110  ref.replace_with(node.lhs.copy())
111  else:
112  new_rhs_term = node.lhs.copy()
113  break
114 
115  # Work out whether the binary operation for this term is a
116  # '+' or a '-' and return it in 'rhs_operator'.
117  rhs_operator = BinaryOperation.Operator.ADD
118  previous = rhs_term
119  candidate = rhs_term.parent
120  while not isinstance(candidate, Assignment):
121  if (isinstance(candidate, BinaryOperation) and
122  candidate.operator == BinaryOperation.Operator.SUB and
123  candidate.children[1] is previous):
124  # Rules: + + -> +; - - -> +; + - -> -; - + -> -
125  # If the higher level op is + then there is no
126  # change to the existing op. If it is - then
127  # we flip the op i.e. - => + and + => -.
128  if rhs_operator == BinaryOperation.Operator.SUB:
129  rhs_operator = BinaryOperation.Operator.ADD
130  else:
131  rhs_operator = BinaryOperation.Operator.SUB
132  previous = candidate
133  candidate = candidate.parent
134 
135  if not active_var:
136  # This is an expression without an active variable
137  # (which must be 0.0, otherwise validation will have
138  # rejected it). There is therefore nothing to output.
139  continue
140 
141  if increment:
142  # The output of an increment needs to be deferred as
143  # other terms must be completed before the LHS TL
144  # active variable is modified. Save the rhs term
145  # and its associated operator.
146  deferred_inc.append((new_rhs_term, rhs_operator))
147  else:
148  # Output the adjoint for this term
149  rhs = BinaryOperation.create(
150  rhs_operator, active_var.copy(), new_rhs_term)
151  assignment = Assignment.create(active_var.copy(), rhs)
152  node.parent.children.insert(node.position, assignment)
153 
154  if (len(deferred_inc) == 1 and
155  isinstance(deferred_inc[0][0], Reference)):
156  # No need to output anything as the adjoint is A = A.
157  pass
158  elif deferred_inc:
159  # Output the adjoint for all increment terms in a single line.
160  rhs, _ = deferred_inc.pop(0)
161  for term, operator in deferred_inc:
162  rhs = BinaryOperation.create(operator, rhs, term)
163  assignment = Assignment.create(node.lhs.copy(), rhs)
164  node.parent.children.insert(node.position, assignment)
165  else:
166  # The assignment is not an increment. The LHS active
167  # variable needs to be zero'ed.
168  assignment = Assignment.create(
169  node.lhs.copy(), Literal("0.0", REAL_TYPE))
170  node.parent.children.insert(node.position, assignment)
171 
172  # Remove the original node
173  node.detach()
174 
175  def _array_ranges_match(self, assign, active_variable):
176  '''
177  If the supplied assignment is to an array range and the supplied
178  active variable is the entity being assigned to then this routine
179  checks that the array ranges of the LHS and the supplied reference
180  match. If they do not then an exception is raised.
181 
182  :param assign: the assignment that we are checking.
183  :type assign: :py:class:`psyclone.psyir.nodes.Assignment`
184  :param active_variable: an active variable that appears on the \
185  LHS and RHS of the supplied assignment.
186  :type active_variable: :py:class:`psyclone.psyir.nodes.Reference`
187 
188  :raises TangentLinearError: if the supplied assignment is to a \
189  symbol with an array range but the same symbol occurs on the \
190  RHS without an array range.
191  :raises NotImplementedError: if the array ranges on the LHS and \
192  RHS of the assignment for the supplied variable do not match.
193 
194  '''
195  # This check only needs to proceed if the assignment is to an array
196  # range and the supplied active variable is the one being assigned to.
197  if not (assign.is_array_assignment and active_variable.symbol is
198  assign.lhs.symbol):
199  return
200 
201  if not isinstance(active_variable, ArrayMixin):
202  raise TangentLinearError(
203  f"Assignment is to an array range but found a "
204  f"reference to the LHS variable "
205  f"'{assign.lhs.symbol.name}' without array notation"
206  f" on the RHS: '{assign.debug_string()}'")
207 
208  sym_maths = SymbolicMaths.get()
209 
210  for pos, idx in enumerate(active_variable.indices):
211  lhs_idx = assign.lhs.indices[pos]
212  # TODO #1537. This is a workaround until the SymbolicMaths
213  # class supports the comparison of array ranges.
214  # pylint: disable=unidiomatic-typecheck
215  if not (type(idx) is type(lhs_idx) and
216  sym_maths.equal(idx.start,
217  lhs_idx.start) and
218  sym_maths.equal(idx.stop,
219  lhs_idx.stop) and
220  sym_maths.equal(idx.step,
221  lhs_idx.step)):
222  raise NotImplementedError(
223  f"Different sections of the same active array "
224  f"'{assign.lhs.symbol.name}' are "
225  f"accessed on the LHS and RHS of an assignment: "
226  f"'{assign.debug_string()}'. This is not supported.")
227 
228  def validate(self, node, options=None):
229  '''Perform various checks to ensure that it is valid to apply the
230  AssignmentTrans transformation to the supplied PSyIR Node.
231 
232  :param node: the node that is being checked.
233  :type node: :py:class:`psyclone.psyir.nodes.Assignment`
234  :param options: a dictionary with options for transformations.
235  :type options: Optional[Dict[str, Any]]
236 
237  :raises TransformationError: if the node argument is not an \
238  Assignment.
239  :raises TangentLinearError: if the assignment does not conform \
240  to the required tangent-linear structure.
241 
242  '''
243  # Check node argument is an assignment node
244  if not isinstance(node, Assignment):
245  raise TransformationError(
246  f"Node argument in assignment transformation should be a "
247  f"PSyIR Assignment, but found '{type(node).__name__}'.")
248  assign = node
249 
250  # If there are no active variables then return
251  assignment_active_var_names = [
252  var.name for var in assign.walk(Reference)
253  if var.symbol in self._active_variables_active_variables]
254  if not assignment_active_var_names:
255  # No active variables in this assignment so the assignment
256  # remains unchanged.
257  return
258 
259  # The lhs of the assignment node should be an active variable
260  if assign.lhs.symbol not in self._active_variables_active_variables:
261  # There are active vars on RHS but not on LHS
262  raise TangentLinearError(
263  f"Assignment node '{assign.debug_string()}' has the following "
264  f"active variables on its RHS '{assignment_active_var_names}' "
265  f"but its LHS '{assign.lhs.name}' is not an active variable.")
266 
267  # Split the RHS of the assignment into <expr> +- <expr> +- <expr>
268  rhs_terms = self._split_nodes_split_nodes(
269  assign.rhs, [BinaryOperation.Operator.ADD,
270  BinaryOperation.Operator.SUB])
271 
272  # Check for the special case where RHS=0.0. This is really a
273  # representation of multiplying an active variable by zero but
274  # this is obviously not visible in the code. Use 'float' to
275  # normalise different representations of 0.
276  if (len(rhs_terms) == 1 and isinstance(rhs_terms[0], Literal) and
277  float(rhs_terms[0].value) == 0.0):
278  return
279 
280  # Check each expression term. It must be in the form
281  # A */ <expr> where A is an active variable.
282  for rhs_term in rhs_terms:
283 
284  # When searching for references to an active variable we must
285  # take care to exclude those cases where they are present as
286  # arguments to the L/UBOUND intrinsics (as they will be when
287  # array notation is used).
288  active_vars = []
289  lu_bound_ops = [IntrinsicCall.Intrinsic.LBOUND,
290  IntrinsicCall.Intrinsic.UBOUND]
291  for ref in rhs_term.walk(Reference):
292  if (ref.symbol in self._active_variables_active_variables and
293  not (isinstance(ref.parent, IntrinsicCall) and
294  ref.parent.intrinsic in lu_bound_ops)):
295  active_vars.append(ref)
296 
297  if not active_vars:
298  # This term must contain an active variable
299  raise TangentLinearError(
300  f"Each non-zero term on the RHS of the assignment "
301  f"'{assign.debug_string()}' must have an active variable "
302  f"but '{rhs_term.debug_string()}' does not.")
303 
304  if len(active_vars) > 1:
305  # This term can only contain one active variable
306  raise TangentLinearError(
307  f"Each term on the RHS of the assignment "
308  f"'{assign.debug_string()}' must not have more than one "
309  f"active variable but '{rhs_term.debug_string()}' has "
310  f"{len(active_vars)}.")
311 
312  if (isinstance(rhs_term, Reference) and rhs_term.symbol
313  in self._active_variables_active_variables):
314  self._array_ranges_match_array_ranges_match(assign, rhs_term)
315  # This term consists of a single active variable (with
316  # a multiplier of unity) and is therefore valid.
317  continue
318 
319  # Ignore unary minus if it is the parent. unary minus does
320  # not cause a problem when applying the transformation but
321  # does cause a problem here in the validation when
322  # splitting the term into expressions.
323  if (isinstance(rhs_term, UnaryOperation) and
324  rhs_term.operator ==
325  UnaryOperation.Operator.MINUS):
326  rhs_term = rhs_term.children[0]
327 
328  # Split the term into <expr> */ <expr> */ <expr>
329  expr_terms = self._split_nodes_split_nodes(
330  rhs_term, [BinaryOperation.Operator.MUL,
331  BinaryOperation.Operator.DIV])
332 
333  # One of the expression terms must be an active variable
334  # or an active variable with a preceding + or -.
335  for expr_term in expr_terms:
336  check_term = expr_term
337  if (isinstance(expr_term, UnaryOperation) and
338  expr_term.operator in [UnaryOperation.Operator.PLUS,
339  UnaryOperation.Operator.MINUS]):
340  check_term = expr_term.children[0]
341  if (isinstance(check_term, Reference) and
342  check_term.symbol in self._active_variables_active_variables):
343  active_variable = check_term
344  break
345  else:
346  raise TangentLinearError(
347  f"Each term on the RHS of the assignment "
348  f"'{assign.debug_string()}' must be linear with respect "
349  f"to the active variable, but found "
350  f"'{rhs_term.debug_string()}'.")
351 
352  # The term must be a product of an active variable with an
353  # inactive expression. Check that the active variable does
354  # not appear in a denominator.
355  candidate = active_variable
356  parent = candidate.parent
357  while not isinstance(parent, Assignment):
358  # Starting with the active variable reference, look up
359  # the tree for an ancestor divide operation until
360  # reaching the assignment node.
361  if (isinstance(parent, BinaryOperation) and
362  parent.operator == BinaryOperation.Operator.DIV and
363  parent.children[1] is candidate):
364  # Found a divide and the active variable is on its RHS
365  raise TangentLinearError(
366  f"In tangent-linear code an active variable cannot "
367  f"appear as a denominator but "
368  f"'{rhs_term.debug_string()}' was found in "
369  f"'{assign.debug_string()}'.")
370  # Continue up the PSyIR tree
371  candidate = parent
372  parent = candidate.parent
373 
374  # If the LHS of the assignment is an array range then we only
375  # support accesses of the same variable on the RHS if they have
376  # the same range.
377  self._array_ranges_match_array_ranges_match(assign, active_variable)
378 
379  @staticmethod
380  def _split_nodes(node, binary_operator_list):
381  '''Utility to split an expression into a series of sub-expressions
382  separated by one of the binary operators specified in
383  binary_operator_list.
384 
385  :param node: the node containing the expression to split.
386  :type node: :py:class:`psyclone.psyir.nodes.DataNode`
387  :param binary_operator_list: list of binary operators.
388  :type binary_operator_list: list of
389  :py:class:`psyclone.psyir.nodes.BinaryOperations.Operator`
390 
391  :returns: a list of sub-expressions.
392  :rtype: list of :py:class:`psyclone.psyir.nodes.DataNode`
393 
394  '''
395  if (isinstance(node, BinaryOperation)) and \
396  (node.operator in binary_operator_list):
397  lhs_node_list = AssignmentTrans._split_nodes(
398  node.children[0], binary_operator_list)
399  rhs_node_list = AssignmentTrans._split_nodes(
400  node.children[1], binary_operator_list)
401  return lhs_node_list + rhs_node_list
402  return [node]
403 
404  def __str__(self):
405  return "Convert a tangent-linear PSyIR Assignment to its adjoint form"
406 
407  @property
408  def name(self):
409  '''
410  :returns: the name of the transformation as a string.
411  :rtype: str
412 
413  '''
414  return type(self).__name__
415 
416 
# =============================================================================
# Documentation utils: the module members that AutoAPI should generate
# documentation for (see https://psyclone-ref.readthedocs.io).
__all__ = [
    "AssignmentTrans",
]
def validate(self, node, options=None)
Definition: psyGen.py:2799