# Source code for absl.testing.flagsaver

# Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Decorator and context manager for saving and restoring flag values.

There are many ways to save and restore.  Always use the most convenient method
for a given use case.

Here are examples of each method.  They all call do_stuff() while FLAGS.someflag
is temporarily set to 'foo'.

  from absl.testing import flagsaver

  # Use a decorator which can optionally override flags via arguments.
  @flagsaver.flagsaver(someflag='foo')
  def some_func():
    do_stuff()

  # Use a decorator which can optionally override flags with flagholders.
  @flagsaver.flagsaver((module.FOO_FLAG, 'foo'), (other_mod.BAR_FLAG, 23))
  def some_func():
    do_stuff()

  # Use a decorator which does not override flags itself.
  @flagsaver.flagsaver
  def some_func():
    FLAGS.someflag = 'foo'
    do_stuff()

  # Use a context manager which can optionally override flags via arguments.
  with flagsaver.flagsaver(someflag='foo'):
    do_stuff()

  # Save and restore the flag values yourself.
  saved_flag_values = flagsaver.save_flag_values()
  try:
    FLAGS.someflag = 'foo'
    do_stuff()
  finally:
    flagsaver.restore_flag_values(saved_flag_values)

We save and restore a shallow copy of each Flag object's __dict__ attribute.
This preserves all attributes of the flag, such as whether or not it was
overridden from its default value.

WARNING: Currently a flag that is saved and then deleted cannot be restored;
restoring flag values silently leaves it deleted.  However if you *add* a flag
after saving flag values, and then restore flag values, the added flag will be
deleted with no errors.
"""

import functools
import inspect

from absl import flags

FLAGS = flags.FLAGS


def flagsaver(*args, **kwargs):
  """The main flagsaver interface. See module doc for usage.

  Supported call shapes:
    * flagsaver(func): the bare `@flagsaver` decorator form. Returns a wrapped
      func that saves and restores all flags around each call.
    * flagsaver(**kwargs): returns a _FlagOverrider that overrides the named
      flags; usable as a decorator or as a context manager.
    * flagsaver((holder, value), ..., **kwargs): as above, with each
      (FlagHolder, value) pair merged into the keyword overrides.

  Args:
    *args: either a single callable, or any number of (FlagHolder, value)
      tuples.
    **kwargs: flag names mapped to their override values.

  Returns:
    The wrapped function (bare-decorator form) or a _FlagOverrider.

  Raises:
    ValueError: if a callable is combined with keyword overrides, if a
      positional argument is not a (FlagHolder, value) pair, or if the same
      flag is set more than once.
    TypeError: if used as a bare decorator on a class.
  """
  # `@flagsaver` (no parentheses) hands us the decorated object as the only
  # positional argument.
  used_as_bare_decorator = len(args) == 1 and callable(args[0])
  if used_as_bare_decorator:
    if kwargs:
      raise ValueError(
          "It's invalid to specify both positional and keyword parameters.")
    target = args[0]
    if inspect.isclass(target):
      raise TypeError('@flagsaver.flagsaver cannot be applied to a class.')
    return _wrap(target, {})

  # Every remaining positional argument must be a (FlagHolder, value) pair;
  # the pairs augment the keyword overrides. With no positional arguments the
  # loop is a no-op and only the keyword overrides are used.
  for pair in args:
    if not (isinstance(pair, tuple) and len(pair) == 2):
      raise ValueError('Expected (FlagHolder, value) pair, found %r' % (pair,))
    holder, new_value = pair
    if not isinstance(holder, flags.FlagHolder):
      raise ValueError('Expected (FlagHolder, value) pair, found %r' % (pair,))
    flag_name = holder.name
    if flag_name in kwargs:
      raise ValueError('Cannot set --%s multiple times' % flag_name)
    kwargs[flag_name] = new_value
  return _FlagOverrider(**kwargs)
def save_flag_values(flag_values=FLAGS):
  """Takes a snapshot of every flag registered in `flag_values`.

  Args:
    flag_values: FlagValues, the FlagValues instance with which the flag will
      be saved. This should almost never need to be overridden.

  Returns:
    Dictionary mapping keys to values. Keys are flag names, values are
    corresponding __dict__ members. E.g. {'key': value_dict, ...}.
  """
  snapshot = {}
  for flag_name in flag_values:
    # Store a copy of the flag's state so that later rebinding of the flag's
    # attributes does not alter the snapshot.
    snapshot[flag_name] = _copy_flag_dict(flag_values[flag_name])
  return snapshot
def restore_flag_values(saved_flag_values, flag_values=FLAGS):
  """Restores flag values based on the dictionary of flag values.

  Flags present in `flag_values` but absent from `saved_flag_values` (i.e.
  flags added after the snapshot was taken) are deleted. Saved flags that no
  longer exist in `flag_values` are skipped.

  The snapshot itself is never installed on a flag. Each flag receives a
  fresh copy of its saved __dict__ (including a fresh validators list), so the
  same snapshot can safely be restored more than once.

  Args:
    saved_flag_values: {'flag_name': value_dict, ...}, as returned by
      save_flag_values().
    flag_values: FlagValues, the FlagValues instance from which the flag will
      be restored. This should almost never need to be overridden.
  """
  # Materialize the names first: flags may be deleted while looping.
  new_flag_names = list(flag_values)
  for name in new_flag_names:
    saved = saved_flag_values.get(name)
    if saved is None:
      # If __dict__ was not saved delete "new" flag.
      delattr(flag_values, name)
      continue
    flag = flag_values[name]
    if flag.value != saved['_value']:
      flag.value = saved['_value']  # Ensure C++ value is set.
    # Install a copy rather than `saved` itself. Aliasing the snapshot would
    # let later mutations of the flag (value, validators, ...) write back into
    # it, silently breaking any subsequent restore from the same snapshot.
    restored = dict(saved)
    if 'validators' in restored:
      restored['validators'] = list(restored['validators'])
    flag.__dict__ = restored
def _wrap(func, overrides): """Creates a wrapper function that saves/restores flag values. Args: func: function object - This will be called between saving flags and restoring flags. overrides: {str: object} - Flag names mapped to their values. These flags will be set after saving the original flag state. Returns: return value from func() """ @functools.wraps(func) def _flagsaver_wrapper(*args, **kwargs): """Wrapper function that saves and restores flags.""" with _FlagOverrider(**overrides): return func(*args, **kwargs) return _flagsaver_wrapper class _FlagOverrider(object): """Overrides flags for the duration of the decorated function call. It also restores all original values of flags after decorated method completes. """ def __init__(self, **overrides): self._overrides = overrides self._saved_flag_values = None def __call__(self, func): if inspect.isclass(func): raise TypeError('flagsaver cannot be applied to a class.') return _wrap(func, self._overrides) def __enter__(self): self._saved_flag_values = save_flag_values(FLAGS) try: FLAGS._set_attributes(**self._overrides) except: # It may fail because of flag validators. restore_flag_values(self._saved_flag_values, FLAGS) raise def __exit__(self, exc_type, exc_value, traceback): restore_flag_values(self._saved_flag_values, FLAGS) def _copy_flag_dict(flag): """Returns a copy of the flag object's __dict__. It's mostly a shallow copy of the __dict__, except it also does a shallow copy of the validator list. Args: flag: flags.Flag, the flag to copy. Returns: A copy of the flag object's __dict__. """ copy = flag.__dict__.copy() copy['_value'] = flag.value # Ensure correct restore for C++ flags. copy['validators'] = list(flag.validators) return copy