Reference Guide  2.5.0
metadata_to_arguments_rules.py
1 # -----------------------------------------------------------------------------
2 # BSD 3-Clause License
3 #
4 # Copyright (c) 2023-2024, Science and Technology Facilities Council.
5 # All rights reserved.
6 #
7 # Redistribution and use in source and binary forms, with or without
8 # modification, are permitted provided that the following conditions are met:
9 #
10 # * Redistributions of source code must retain the above copyright notice, this
11 # list of conditions and the following disclaimer.
12 #
13 # * Redistributions in binary form must reproduce the above copyright notice,
14 # this list of conditions and the following disclaimer in the documentation
15 # and/or other materials provided with the distribution.
16 #
17 # * Neither the name of the copyright holder nor the names of its
18 # contributors may be used to endorse or promote products derived from
19 # this software without specific prior written permission.
20 #
21 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
24 # FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
25 # COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
26 # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
27 # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
28 # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
29 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
30 # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
31 # ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
32 # POSSIBILITY OF SUCH DAMAGE.
33 # -----------------------------------------------------------------------------
34 # Author: R. W. Ford, STFC Daresbury Lab
35 
36 '''This module implements a class that encapsulates rules that map
37  LFRic kernel metadata to kernel arguments.
38 
39 '''
40 from collections import OrderedDict
41 import re
42 
43 from psyclone.domain.lfric import LFRicConstants
44 from psyclone.domain.lfric.kernel import (
45  OperatorArgMetadata, ColumnwiseOperatorArgMetadata, FieldArgMetadata,
46  FieldVectorArgMetadata, InterGridArgMetadata, InterGridVectorArgMetadata,
47  ScalarArgMetadata)
48 from psyclone.errors import InternalError
49 
50 
# pylint: disable=too-few-public-methods
class MetadataToArgumentsRules():
    '''This class encapsulates rules to map LFRic kernel metadata to
    kernel arguments. It does this by calling class methods, each of
    which represents a particular kernel argument or set of
    arguments. It calls these in the order that the arguments should
    appear in the kernel argument list. The particular methods called
    and their ordering are determined by the supplied kernel metadata.

    Kernel argument information from kernel metadata can be used for
    more than one thing, e.g. to create or check arguments within a
    kernel and their declarations (using PSyIR), create the arguments
    to a kernel call from the generated PSy-layer code or to create
    appropriate algorithm PSyIR for an Invoke of the kernel or the
    resulting call to the PSy-layer routine. Subclasses of this class
    can be implemented for these different requirements.

    Every argument method in this class is a no-op; a subclass
    overrides the ones it needs. Note that the metadata being processed
    and the accumulated _info are stored on the class itself (not on an
    instance), so calls to mapping() on the same class must not be
    interleaved.

    '''
    # The kernel metadata currently being processed (set by mapping()).
    _metadata = None
    # The container that subclasses add argument information to (set by
    # _initialise() and returned by mapping()).
    _info = None
    # Regex used to identify the special 'enforce_bc_code' kernel that
    # applies boundary conditions to a field. Allows for the renaming
    # performed by PSyclone when performing kernel transformations.
    # TODO #487 - this can be removed when we have metadata to specify
    # that a kernel applies boundary conditions.
    bc_kern_regex = re.compile(r"enforce_bc_(\d+_)?code", flags=re.I)

    @classmethod
    def mapping(cls, metadata, info=None):
        '''Takes kernel metadata as input and returns whatever is added to
        the _info variable. This class adds nothing to the _info
        variable; it is up to a subclass to do this. The variable is
        treated as a container. The optional info argument allows a
        subclass to add to an existing object if required.

        :param metadata: the kernel metadata.
        :type metadata: \
            :py:class:`psyclone.domain.lfric.kernel.LFRicKernelMetadata`
        :param info: optional object used to initialise the _info \
            variable. Defaults to None.
        :type info: Any

        :returns: the _info variable once all argument rules have been \
            applied.
        :rtype: Any

        '''
        cls._initialise(info)
        cls._metadata = metadata
        cls._generate()
        return cls._info

    @classmethod
    def _initialise(cls, info):
        '''Initialise the _info variable. This is implemented as its own
        method to allow for simple subclassing (i.e. the mapping
        method should not need to be subclassed).

        :param info: object used to initialise the _info variable.
        :type info: Any

        '''
        cls._info = info

    @classmethod
    def _cell_position(cls):
        '''A cell position (cell index) argument.'''

    @classmethod
    def _mesh_height(cls):
        '''A mesh height (number of layers) argument.'''

    @classmethod
    def _mesh_ncell2d_no_halos(cls):
        '''Argument providing the number of columns in the mesh ignoring
        halos.

        '''

    @classmethod
    def _mesh_ncell2d(cls):
        '''Argument providing the number of columns in the mesh including
        halos.

        '''

    @classmethod
    def _cell_map(cls):
        '''Arguments providing a mapping from coarse to fine mesh for the
        current column.

        '''

    @classmethod
    def _scalar(cls, meta_arg):
        '''Argument providing an LFRic scalar value.

        :param meta_arg: the metadata associated with this scalar argument.
        :type meta_arg: \
            :py:class:`psyclone.domain.lfric.kernel.ScalarArgMetadata`

        '''

    @classmethod
    def _field(cls, meta_arg):
        '''Argument providing an LFRic field.

        :param meta_arg: the metadata associated with this field argument.
        :type meta_arg: \
            :py:class:`psyclone.domain.lfric.kernel.FieldArgMetadata`

        '''

    @classmethod
    def _field_vector(cls, meta_arg):
        '''Arguments providing the components of an LFRic field vector.

        :param meta_arg: the metadata associated with this field \
            vector argument.
        :type meta_arg: \
            :py:class:`psyclone.domain.lfric.kernel.FieldVectorArgMetadata`

        '''

    @classmethod
    def _operator(cls, meta_arg):
        '''Arguments providing an LMA operator.

        :param meta_arg: the metadata associated with the operator \
            arguments.
        :type meta_arg: \
            :py:class:`psyclone.domain.lfric.kernel.OperatorArgMetadata`

        '''

    @classmethod
    def _cma_operator(cls, meta_arg):
        '''Arguments providing a columnwise operator.

        :param meta_arg: the metadata associated with the CMA operator \
            arguments.
        :type meta_arg: :py:class:`psyclone.domain.lfric.kernel.\
            ColumnwiseOperatorArgMetadata`

        '''

    @classmethod
    def _ref_element_properties(cls, meta_ref_element):
        '''Arguments required if there are reference-element properties
        specified in the metadata.

        :param meta_ref_element: the metadata capturing the \
            reference-element properties required by the kernel.
        :type meta_ref_element: List[\
            :py:class:`psyclone.domain.lfric.kernel.MetaRefElementArgMetadata`]

        '''

    @classmethod
    def _mesh_properties(cls, meta_mesh):
        '''All arguments required for mesh properties specified in the kernel
        metadata.

        :param meta_mesh: the metadata capturing the mesh properties \
            required by the kernel.
        :type meta_mesh: List[\
            :py:class:`psyclone.domain.lfric.kernel.MetaMeshArgMetadata`]

        '''

    @classmethod
    def _fs_common(cls, function_space):
        '''Arguments associated with a function space that are common to
        fields and operators.

        :param str function_space: the current function space.

        '''

    @classmethod
    def _fs_compulsory_field(cls, function_space):
        '''Compulsory arguments for a function space on which there is a
        field or field vector.

        :param str function_space: the current function space.

        '''

    @classmethod
    def _fs_intergrid(cls, meta_arg):
        '''Function-space related arguments for an intergrid kernel.

        :param meta_arg: the metadata capturing the InterGrid argument \
            required by the kernel.
        :type meta_arg: \
            :py:class:`psyclone.domain.lfric.kernel.InterGridArgMetadata`

        '''

    @classmethod
    def _basis(cls, function_space):
        '''Arguments associated with basis functions on the supplied function
        space.

        :param str function_space: the current function space.

        '''

    @classmethod
    def _diff_basis(cls, function_space):
        '''Arguments associated with differential basis functions on the
        supplied function space.

        :param str function_space: the current function space.

        '''

    @classmethod
    def _quad_rule(cls, shapes):
        '''Quadrature information is required (gh_shape =
        gh_quadrature_*). Shape information is provided for each shape
        in the order specified in the gh_shape metadata.

        :param shapes: the metadata capturing the quadrature shapes \
            required by the kernel.
        :type shapes: List[str]

        '''

    @classmethod
    def _field_bcs_kernel(cls):
        '''Fix for the field boundary condition kernel.'''

    @classmethod
    def _operator_bcs_kernel(cls):
        '''Fix for the operator boundary condition kernel.'''

    @classmethod
    def _stencil_cross2d_extent(cls, meta_arg):
        '''The field has a stencil access of type 'cross2d' of unknown extent
        and therefore requires the extent to be passed from the
        algorithm layer.

        :param meta_arg: the metadata associated with a field argument \
            with a cross2d stencil access.
        :type meta_arg: \
            :py:class:`psyclone.domain.lfric.kernel.FieldArgMetadata`

        '''

    @classmethod
    def _stencil_cross2d_max_extent(cls, meta_arg):
        '''The field has a stencil access of type 'cross2d' and requires the
        maximum size of a stencil extent to be passed from the
        algorithm layer.

        :param meta_arg: the metadata associated with a field argument \
            with a cross2d stencil access.
        :type meta_arg: \
            :py:class:`psyclone.domain.lfric.kernel.FieldArgMetadata`

        '''

    @classmethod
    def _stencil_extent(cls, meta_arg):
        '''The field has a stencil access (that is not of type 'cross2d') of
        unknown extent and therefore requires the extent to be passed
        from the algorithm layer.

        :param meta_arg: the metadata associated with a field argument \
            with a stencil access.
        :type meta_arg: \
            :py:class:`psyclone.domain.lfric.kernel.FieldArgMetadata`

        '''

    @classmethod
    def _stencil_xory1d_direction(cls, meta_arg):
        '''The field has a stencil access of type 'xory1d' and therefore
        requires the stencil direction to be passed from the algorithm
        layer.

        :param meta_arg: the metadata associated with a field argument \
            with a xory1d stencil access.
        :type meta_arg: \
            :py:class:`psyclone.domain.lfric.kernel.FieldArgMetadata`

        '''

    @classmethod
    def _stencil_cross2d(cls, meta_arg):
        '''Stencil information that is always passed from the algorithm layer
        if a field has a stencil access of type 'cross2d'.

        :param meta_arg: the metadata associated with a field argument \
            with a stencil access.
        :type meta_arg: \
            :py:class:`psyclone.domain.lfric.kernel.FieldArgMetadata`

        '''

    @classmethod
    def _stencil(cls, meta_arg):
        '''Stencil information that is always passed from the algorithm layer
        if a field has a stencil access that is not of type 'cross2d'.

        :param meta_arg: the metadata associated with a field argument \
            with a stencil access.
        :type meta_arg: \
            :py:class:`psyclone.domain.lfric.kernel.FieldArgMetadata`

        '''

    @classmethod
    def _banded_dofmap(cls, function_space, cma_operator):
        '''Adds a banded dofmap for the provided function space and CMA
        operator when there is a CMA-assembly kernel.

        :param str function_space: the function space for this banded \
            dofmap.
        :param cma_operator: the CMA operator metadata associated with \
            this banded dofmap.
        :type cma_operator: :py:class:`psyclone.domain.lfric.kernel.\
            ColumnwiseOperatorArgMetadata`

        '''

    @classmethod
    def _indirection_dofmap(cls, function_space, cma_operator):
        '''Adds an indirection dofmap for the provided function space and CMA
        operator when there is a CMA-apply kernel.

        :param str function_space: the function space for this \
            indirection dofmap.
        :param cma_operator: the CMA operator metadata associated with \
            this indirection dofmap.
        :type cma_operator: :py:class:`psyclone.domain.lfric.kernel.\
            ColumnwiseOperatorArgMetadata`

        '''

    # pylint: disable=too-many-branches
    # pylint: disable=too-many-statements
    @classmethod
    def _generate(cls):
        '''Specifies which arguments appear in an argument list and their
        ordering. Calls methods for each type of argument. These
        methods can be specialised by a subclass for its particular
        need.

        :raises InternalError: if an unexpected meta_arg type is found.

        '''
        # pylint: disable=unidiomatic-typecheck
        # Exact type checks (rather than isinstance) are used
        # deliberately throughout. NOTE(review): presumably some of
        # these metadata classes derive from one another; confirm before
        # switching to isinstance.

        # All operator types require the cell index to be provided
        if cls._metadata.meta_args_get(
                [OperatorArgMetadata, ColumnwiseOperatorArgMetadata]):
            cls._cell_position()

        # Pass the number of layers in the mesh unless this kernel is
        # applying a CMA operator or doing a CMA matrix-matrix calculation
        if cls._metadata.kernel_type not in ["cma-apply", "cma-matrix-matrix"]:
            cls._mesh_height()

        # Pass the number of cells in the mesh if this kernel has a
        # LMA operator argument.
        # TODO issue #2074 this call should be used to replace the
        # code that currently includes ncell3d for *every* operator it
        # encounters (in _operator()).
        # if cls._metadata.meta_args_get(OperatorArgMetadata):
        #     cls._mesh_ncell3d()

        # Pass the number of columns in the mesh if this kernel operates on
        # the 'domain' or has a CMA operator argument. For the former we
        # exclude halo columns.
        if cls._metadata.operates_on == "domain":
            cls._mesh_ncell2d_no_halos()
        if cls._metadata.meta_args_get(ColumnwiseOperatorArgMetadata):
            cls._mesh_ncell2d()

        if cls._metadata.kernel_type == "inter-grid":
            # Inter-grid kernels require special arguments: the
            # cell-map for the current column providing the mapping
            # from the coarse to the fine mesh.
            cls._cell_map()

        # For each argument in the order they are specified in the
        # kernel metadata, call particular methods depending on what
        # type of argument we find (field, field vector, operator or
        # scalar). If the argument is a field or field vector and also
        # has a stencil access then also call appropriate stencil
        # methods.
        for meta_arg in cls._metadata.meta_args:
            arg_type = type(meta_arg)
            if arg_type in [FieldArgMetadata, FieldVectorArgMetadata,
                            InterGridArgMetadata,
                            InterGridVectorArgMetadata]:
                if arg_type in [FieldArgMetadata, InterGridArgMetadata]:
                    cls._field(meta_arg)
                else:
                    cls._field_vector(meta_arg)
                if meta_arg.stencil == "cross2d":
                    # The stencil extent is not provided in the metadata
                    # so must be passed from the Algorithm layer. Due to
                    # the nature of the stencil extent array, the max
                    # size of a stencil branch must also be passed from
                    # the Algorithm layer.
                    cls._stencil_cross2d_extent(meta_arg)
                    cls._stencil_cross2d_max_extent(meta_arg)
                    # Stencil information that is always passed from
                    # the Algorithm layer.
                    cls._stencil_cross2d(meta_arg)
                elif meta_arg.stencil:
                    # The stencil extent is not provided in the metadata
                    # so must be passed from the Algorithm layer.
                    cls._stencil_extent(meta_arg)
                    if meta_arg.stencil == "xory1d":
                        # If xory1d is specified then the actual
                        # direction must be passed from the Algorithm
                        # layer.
                        cls._stencil_xory1d_direction(meta_arg)
                    # Stencil information that is always passed from
                    # the Algorithm layer.
                    cls._stencil(meta_arg)
            elif arg_type is OperatorArgMetadata:
                cls._operator(meta_arg)
            elif arg_type is ColumnwiseOperatorArgMetadata:
                cls._cma_operator(meta_arg)
            elif arg_type is ScalarArgMetadata:
                cls._scalar(meta_arg)
            else:
                raise InternalError(
                    f"Unexpected meta_arg type '{type(meta_arg).__name__}' "
                    f"found.")

        # Collect each unique function space in the order in which it
        # first appears in the metadata arguments (the OrderedDict is
        # used as an insertion-ordered set). Operators contribute both
        # their 'to' and 'from' spaces, in that order.
        function_space_args = cls._metadata.meta_args_get(
            [FieldArgMetadata, FieldVectorArgMetadata,
             InterGridArgMetadata, InterGridVectorArgMetadata,
             OperatorArgMetadata, ColumnwiseOperatorArgMetadata])
        unique_function_spaces = OrderedDict()
        for arg in function_space_args:
            if type(arg) in [
                    OperatorArgMetadata, ColumnwiseOperatorArgMetadata]:
                unique_function_spaces[arg.function_space_to] = None
                unique_function_spaces[arg.function_space_from] = None
            else:
                unique_function_spaces[arg.function_space] = None

        # The basis-function metadata does not depend on the function
        # space, so look it up once rather than on every iteration.
        meta_funcs = cls._metadata.meta_funcs or []

        for function_space in unique_function_spaces:
            # Provide function-space-specific arguments common to
            # fields and LMA operators unless this is an inter-grid or
            # CMA matrix-matrix kernel.
            if cls._metadata.kernel_type not in [
                    "cma-matrix-matrix", "inter-grid"]:
                cls._fs_common(function_space)

            # Provide additional arguments if there is a field or
            # field vector on this space.
            if cls._metadata.field_meta_args_on_fs(
                    [FieldArgMetadata, FieldVectorArgMetadata],
                    function_space):
                cls._fs_compulsory_field(function_space)

            # Provide additional arguments if there is an intergrid
            # field or intergrid vector field on this space (only the
            # first such argument is passed on).
            intergrid_field = cls._metadata.field_meta_args_on_fs(
                [InterGridArgMetadata, InterGridVectorArgMetadata],
                function_space)
            if intergrid_field:
                cls._fs_intergrid(intergrid_field[0])

            # CMA kernels need a dofmap for this space (only the first
            # CMA operator on the space is passed on).
            cma_ops = cls._metadata.operator_meta_args_on_fs(
                ColumnwiseOperatorArgMetadata, function_space)
            if cma_ops:
                if cls._metadata.kernel_type == "cma-assembly":
                    # CMA-assembly requires banded dofmaps
                    cls._banded_dofmap(function_space, cma_ops[0])
                elif cls._metadata.kernel_type == "cma-apply":
                    # Applying a CMA operator requires indirection dofmaps
                    cls._indirection_dofmap(function_space, cma_ops[0])

            # Provide any optional arguments. These arguments are
            # associated with the keyword arguments (basis function
            # and differential basis function) for a function space.
            if any(func.basis_function and
                   func.function_space == function_space
                   for func in meta_funcs):
                cls._basis(function_space)
            if any(func.diff_basis_function and
                   func.function_space == function_space
                   for func in meta_funcs):
                cls._diff_basis(function_space)

        # The boundary condition kernel (enforce_bc_kernel) is a
        # special case.
        if (cls._metadata.procedure_name and
                cls.bc_kern_regex.match(cls._metadata.procedure_name)):
            cls._field_bcs_kernel()

        # The operator boundary condition kernel
        # (enforce_operator_bc_kernel) is a special case.
        # NOTE(review): unlike bc_kern_regex above, this exact-name
        # comparison does not allow for kernels renamed by PSyclone
        # transformations; confirm whether that is intended.
        if (cls._metadata.procedure_name and
                cls._metadata.procedure_name.lower() ==
                "enforce_operator_bc_code"):
            cls._operator_bcs_kernel()

        # Reference-element properties
        if cls._metadata.meta_ref_element:
            cls._ref_element_properties(cls._metadata.meta_ref_element)

        # Mesh properties
        if cls._metadata.meta_mesh:
            cls._mesh_properties(cls._metadata.meta_mesh)

        # Quadrature arguments are required if one or more basis or
        # differential basis functions are used by the kernel and a
        # quadrature shape is supplied.
        const = LFRicConstants()
        if (meta_funcs and cls._metadata.shapes and
                any(shape in const.VALID_QUADRATURE_SHAPES
                    for shape in cls._metadata.shapes)):
            cls._quad_rule(cls._metadata.shapes)