Framework (contrib)
[TOC]
Framework utilities.
tf.contrib.framework.assert_same_float_dtype(tensors=None, dtype=None)
Validate and return float type based on tensors and dtype.
For ops such as matrix multiplication, inputs and weights must be of the same float type. This function validates that all tensors are the same type, validates that type is dtype (if supplied), and returns the type. Type must be dtypes.float32 or dtypes.float64. If neither tensors nor dtype is supplied, default to dtypes.float32.
Args:
tensors: Tensors of input values. Can include None elements, which will be ignored.
dtype: Expected type.
Returns:
Validated type.
Raises:
ValueError: if tensors are not all of the same type, if their type does not match dtype, or if the result is not float.
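As a rough illustration of the validation rules above, here is a pure-Python sketch that models dtypes as plain strings instead of tf.DType objects. The helper same_float_dtype and the ALLOWED_FLOAT_DTYPES tuple are hypothetical names for this example; this models the documented contract and is not the TensorFlow implementation.

```python
from typing import Iterable, Optional

# Assumption: dtypes are modelled as plain strings, not tf.DType objects.
ALLOWED_FLOAT_DTYPES = ("float32", "float64")


def same_float_dtype(tensor_dtypes: Iterable[Optional[str]],
                     dtype: Optional[str] = None) -> str:
    """Hypothetical helper mirroring the documented assert_same_float_dtype rules."""
    for t in tensor_dtypes:
        if t is None:
            continue  # None entries are ignored, as documented
        if dtype is None:
            dtype = t  # the first non-None input fixes the expected type
        elif t != dtype:
            raise ValueError(f"Expected {dtype}, got {t}.")
    if dtype is None:
        return "float32"  # neither tensors nor dtype supplied
    if dtype not in ALLOWED_FLOAT_DTYPES:
        raise ValueError(f"Expected float, got {dtype}.")
    return dtype
```

For example, same_float_dtype(["float64", None, "float64"]) yields "float64", an empty input falls back to "float32", and mixing "float32" with "float64" raises ValueError.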
tf.contrib.framework.assert_scalar_int(tensor)
Assert tensor is 0-D, of type tf.int32 or tf.int64.
Args:
tensor: Tensor to test.
Returns:
tensor, for chaining.
Raises:
ValueError: if tensor is not 0-D, or is not of type tf.int32 or tf.int64.
tf.contrib.framework.convert_to_tensor_or_sparse_tensor(value, dtype=None, name=None, as_ref=False)
Converts value to a SparseTensor or Tensor.
Args:
value: A SparseTensor, SparseTensorValue, or an object whose type has a registered Tensor conversion function.
dtype: Optional element type for the returned tensor. If missing, the type is inferred from the type of value.
name: Optional name to use if a new Tensor is created.
as_ref: True if we want the result as a ref tensor. Only used if a new Tensor is created.
Returns:
A SparseTensor or Tensor based on value.
Raises:
RuntimeError: If result type is incompatible with dtype.
tf.contrib.framework.get_graph_from_inputs(op_input_list, graph=None)
Returns the appropriate graph to use for the given inputs.
- If graph is provided, we validate that all inputs in op_input_list are from the same graph.
- Otherwise, we attempt to select a graph from the first Operation- or Tensor-valued input in op_input_list, and validate that all other such inputs are in the same graph.
- If the graph was not specified and it could not be inferred from op_input_list, we attempt to use the default graph.
Args:
op_input_list: A list of inputs to an operation, which may include Tensor, Operation, and other objects that may be converted to a graph element.
graph: (Optional) The explicit graph to use.
Raises:
TypeError: If op_input_list is not a list or tuple, or if graph is not a Graph.
ValueError: If a graph is explicitly passed and not all inputs are from it, or if the inputs are from multiple graphs, or we could not find a graph and there was no default graph.
Returns:
The appropriate graph to use for the given inputs.
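The selection order described above can be sketched in plain Python. Graph, Node, DEFAULT_GRAPH and select_graph are hypothetical stand-ins for this example, and treating any input without a .graph attribute as a non-graph value (skipped during inference) is an assumption made for the sketch.

```python
from typing import Any, Optional, Sequence


class Graph:
    """Hypothetical stand-in for a computation graph."""


class Node:
    """Hypothetical graph element (think Tensor or Operation) that knows its graph."""

    def __init__(self, graph: Graph) -> None:
        self.graph = graph


DEFAULT_GRAPH: Optional[Graph] = Graph()  # hypothetical default graph


def select_graph(op_input_list: Sequence[Any], graph: Optional[Graph] = None) -> Graph:
    if not isinstance(op_input_list, (list, tuple)):
        raise TypeError("op_input_list must be a list or tuple.")
    if graph is not None and not isinstance(graph, Graph):
        raise TypeError("graph must be a Graph.")
    chosen = graph  # an explicit graph takes priority
    for item in op_input_list:
        item_graph = getattr(item, "graph", None)
        if item_graph is None:
            continue  # plain Python values carry no graph
        if chosen is None:
            chosen = item_graph  # first graph element selects the graph
        elif item_graph is not chosen:
            raise ValueError("Inputs come from a different graph.")
    if chosen is None:
        if DEFAULT_GRAPH is None:
            raise ValueError("No graph found and no default graph.")
        chosen = DEFAULT_GRAPH  # fall back to the default graph
    return chosen
```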
tf.is_numeric_tensor(tensor)
tf.is_non_decreasing(x, name=None)
Returns True if x is non-decreasing.
Elements of x are compared in row-major order. The tensor [x[0],...] is non-decreasing if for every adjacent pair we have x[i] <= x[i+1].
If x has fewer than two elements, it is trivially non-decreasing.
See also: is_strictly_increasing
Args:
x: Numeric Tensor.
name: A name for this operation (optional). Defaults to "is_non_decreasing".
Returns:
Boolean Tensor, equal to True iff x is non-decreasing.
Raises:
TypeError: if x is not a numeric tensor.
tf.is_strictly_increasing(x, name=None)
Returns True if x is strictly increasing.
Elements of x are compared in row-major order. The tensor [x[0],...] is strictly increasing if for every adjacent pair we have x[i] < x[i+1].
If x has fewer than two elements, it is trivially strictly increasing.
See also: is_non_decreasing
Args:
x: Numeric Tensor.
name: A name for this operation (optional). Defaults to "is_strictly_increasing".
Returns:
Boolean Tensor, equal to True iff x is strictly increasing.
Raises:
TypeError: if x is not a numeric tensor.
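Both predicates reduce to comparing adjacent elements after flattening in row-major order. The sketch below assumes nested Python lists in place of tensors; flatten_row_major, non_decreasing and strictly_increasing are illustrative helpers, not the TensorFlow ops.

```python
from typing import Any, List


def flatten_row_major(x: Any) -> List[Any]:
    """Flatten a nested list in row-major (C) order; a scalar becomes [x]."""
    if isinstance(x, (list, tuple)):
        out: List[Any] = []
        for item in x:
            out.extend(flatten_row_major(item))
        return out
    return [x]


def non_decreasing(x: Any) -> bool:
    flat = flatten_row_major(x)
    # all() over zero pairs is True, so fewer than two elements passes trivially.
    return all(a <= b for a, b in zip(flat, flat[1:]))


def strictly_increasing(x: Any) -> bool:
    flat = flatten_row_major(x)
    return all(a < b for a, b in zip(flat, flat[1:]))
```

Note that row-major order compares the last element of one row with the first element of the next, so [[1, 5], [2, 3]] is not non-decreasing even though each row is.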
tf.contrib.framework.is_tensor(x)
Check for tensor types.
Check whether an object is a tensor. Equivalent to isinstance(x, (tf.Tensor, tf.SparseTensor, tf.Variable)).
Args:
x: A Python object to check.
Returns:
True if x is a tensor, False if not.
tf.contrib.framework.reduce_sum_n(tensors, name=None)
Reduce tensors to a scalar sum.
This reduces each tensor in tensors to a scalar via tf.reduce_sum, then adds them via tf.add_n.
Args:
tensors: List of tensors, all of the same numeric type.
name: Tensor name, and scope for all other ops.
Returns:
Scalar tensor holding the total sum.
Raises:
ValueError: if tensors is missing or empty.
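A minimal pure-Python model of the two-step reduction (a per-tensor sum, then a sum of the resulting scalars), assuming nested lists stand in for tensors. reduce_sum_n_sketch and _sum_all are hypothetical names, not TensorFlow functions.

```python
from typing import Any, List


def _sum_all(x: Any) -> float:
    """Sum every element of a (possibly nested) list, like reduce_sum with no axis."""
    if isinstance(x, (list, tuple)):
        return sum(_sum_all(item) for item in x)
    return x


def reduce_sum_n_sketch(tensors: List[Any]) -> float:
    if not tensors:
        raise ValueError("No tensors provided.")
    scalars = [_sum_all(t) for t in tensors]  # step 1: each tensor to a scalar
    return sum(scalars)  # step 2: add the scalars together (add_n)
```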
tf.contrib.framework.with_shape(expected_shape, tensor)
Asserts tensor has expected shape.
If tensor shape and expected_shape are fully defined, assert they match. Otherwise, add an assert op that will validate the shape when tensor is evaluated, and set the shape on tensor.
Args:
expected_shape: Expected shape to assert, as a 1D array of ints, or a tensor of same.
tensor: Tensor whose shape we're validating.
Returns:
tensor, perhaps with a dependent assert operation.
Raises:
ValueError: if tensor has an invalid shape.
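The static-versus-deferred decision can be modelled without TensorFlow. In this simplified sketch a shape is a sequence of ints with None marking an unknown dimension (an assumption), and check_shape is a hypothetical helper that returns "deferred" in the case where with_shape would instead attach a runtime assert op.

```python
from typing import Optional, Sequence

# Assumption: a shape is a sequence of ints, with None for an unknown dimension.
Shape = Sequence[Optional[int]]


def is_fully_defined(shape: Shape) -> bool:
    return all(dim is not None for dim in shape)


def check_shape(expected_shape: Shape, actual_shape: Shape) -> str:
    """Return "ok" or "deferred"; raise ValueError on a static mismatch."""
    if is_fully_defined(expected_shape) and is_fully_defined(actual_shape):
        if list(expected_shape) != list(actual_shape):
            raise ValueError(
                f"Invalid shape {list(actual_shape)}, expected {list(expected_shape)}.")
        return "ok"
    # Not decidable statically: the real op would add a runtime assert instead.
    return "deferred"
```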
tf.contrib.framework.with_same_shape(expected_tensor, tensor)
Assert tensors are the same shape, from the same graph.
Args:
expected_tensor: Tensor with expected shape.
tensor: Tensor of actual values.
Returns:
Tuple of (actual_tensor, label_tensor), possibly with assert ops added.
Deprecation
tf.contrib.framework.deprecated(date, instructions)
Decorator for marking functions or methods deprecated.
This decorator logs a deprecation warning whenever the decorated function is called.
It also edits the docstring of the function: ' (deprecated)' is appended to the first line of the docstring and a deprecation notice is prepended to the rest of the docstring.
Args:
date: String. The date the function is scheduled to be removed. Must be ISO 8601 (YYYY-MM-DD).
instructions: String. Instructions on how to update code using the deprecated function.
Returns:
Decorated function or method.
Raises:
ValueError: If date is not in ISO 8601 format, or instructions are empty.
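A minimal sketch of a decorator with the documented behavior, written against the Python standard library only. It uses warnings.warn rather than TensorFlow's logging, the date check only verifies the YYYY-MM-DD shape, and deprecated_sketch and its message wording are hypothetical; it is not the TensorFlow source.

```python
import functools
import re
import warnings
from typing import Any, Callable

_ISO_DATE = re.compile(r"\d{4}-\d{2}-\d{2}")  # checks the YYYY-MM-DD shape only


def deprecated_sketch(date: str, instructions: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Hypothetical stand-in for a deprecation decorator."""
    # Validate eagerly, so bad arguments fail when the decorator is applied.
    if not isinstance(date, str) or not _ISO_DATE.fullmatch(date):
        raise ValueError(f"Date must be YYYY-MM-DD, got {date!r}.")
    if not instructions:
        raise ValueError("Instructions must not be empty.")

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        message = (f"{func.__name__} is deprecated and will be removed after {date}.\n"
                   f"Instructions for updating:\n{instructions}")

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        # Append " (deprecated)" to the first docstring line, then prepend
        # the notice to the rest of the docstring.
        first, _, rest = (func.__doc__ or "").partition("\n")
        wrapper.__doc__ = f"{first} (deprecated)\n\n{message}\n{rest}"
        return wrapper

    return decorator
```

Applying it as @deprecated_sketch("2017-01-01", "Use new_fn instead.") leaves the function's behavior unchanged while emitting a DeprecationWarning on every call.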
tf.contrib.framework.deprecated_args(date, instructions, *deprecated_arg_names)
Decorator for marking specific function arguments as deprecated.
This decorator logs a deprecation warning whenever the decorated function is called with the deprecated argument.
It also edits the docstring of the function: ' (deprecated arguments)' is appended to the first line of the docstring and a deprecation notice is prepended to the rest of the docstring.
Args:
date: String. The date the function is scheduled to be removed. Must be ISO 8601 (YYYY-MM-DD).
instructions: String. Instructions on how to update code using the deprecated function.
*deprecated_arg_names: String. The deprecated arguments.
Returns:
Decorated function or method.
Raises:
ValueError: If date is not in ISO 8601 format, instructions are empty, or the deprecated arguments are not present in the function signature.
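The argument-level variant differs mainly in two checks: the named arguments must exist in the function signature (checked once, when the decorator is applied), and a warning fires only when a caller actually passes one of them. This is a hypothetical stdlib-only sketch (deprecated_args_sketch is not a TensorFlow name), and it skips date validation for brevity.

```python
import functools
import inspect
import warnings
from typing import Any, Callable


def deprecated_args_sketch(date: str, instructions: str, *deprecated_arg_names: str):
    """Hypothetical decorator: warn when any named argument is explicitly passed."""
    if not instructions:
        raise ValueError("Instructions must not be empty.")
    if not deprecated_arg_names:
        raise ValueError("Specify which arguments are deprecated.")

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        sig = inspect.signature(func)
        missing = [n for n in deprecated_arg_names if n not in sig.parameters]
        if missing:
            raise ValueError(f"{func.__name__} has no arguments named {missing}.")

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # bind() maps positional and keyword values to parameter names
            # without filling in defaults, so only explicitly passed args appear.
            passed = sig.bind(*args, **kwargs).arguments
            for name in deprecated_arg_names:
                if name in passed:
                    warnings.warn(
                        f"Calling {func.__name__} with {name} is deprecated and will be "
                        f"removed after {date}.\nInstructions for updating:\n{instructions}",
                        DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        return wrapper

    return decorator
```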
tf.contrib.framework.deprecated_arg_values(date, instructions, **deprecated_kwargs)
Decorator for marking specific function argument values as deprecated.
This decorator logs a deprecation warning whenever the decorated function is called with the deprecated argument values.
It also edits the docstring of the function: ' (deprecated arguments)' is appended to the first line of the docstring and a deprecation notice is prepended to the rest of the docstring.
Args:
date: String. The date the function is scheduled to be removed. Must be ISO 8601 (YYYY-MM-DD).
instructions: String. Instructions on how to update code using the deprecated function.
**deprecated_kwargs: The deprecated argument values.
Returns:
Decorated function or method.
Raises:
ValueError: If date is not in ISO 8601 format, or instructions are empty.
Arg_Scope
tf.contrib.framework.arg_scope(list_ops_or_scope, **kwargs)
Stores the default arguments for the given set of list_ops.
For usage, please see the examples at the top of the source file.
Args:
list_ops_or_scope: List or tuple of operations to set argument scope for, or a dictionary containing the current scope. When list_ops_or_scope is a dict, kwargs must be empty. When list_ops_or_scope is a list or tuple, every op in it needs to be decorated with @add_arg_scope to work.
**kwargs: keyword=value that will define the defaults for each op in list_ops. All the ops need to accept the given set of arguments.
Yields:
the current_scope, which is a dictionary of {op: {arg: value}}
Raises:
TypeError: if list_ops is not a list or a tuple.
ValueError: if any op in list_ops has not been decorated with @add_arg_scope.
tf.contrib.framework.add_arg_scope(func)
Decorates a function with args so it can be used within an arg_scope.
Args:
func: function to decorate.
Returns:
The decorated function func_with_args().
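To show how arg_scope and add_arg_scope cooperate, here is a pure-Python model: a stack of {function: {arg: value}} dicts, a decorator that merges the innermost scope's defaults into each call, and a context manager that pushes a new scope. All names are hypothetical, the dict form of list_ops_or_scope is not modelled, and this is not the TensorFlow implementation.

```python
import contextlib
import functools
from typing import Any, Callable, Dict, Iterator, List, Sequence

# Stack of {decorated_function: {arg: value}} dicts; the last entry is the current scope.
_SCOPE_STACK: List[Dict[Any, Dict[str, Any]]] = [{}]


def add_arg_scope_sketch(func: Callable[..., Any]) -> Callable[..., Any]:
    @functools.wraps(func)
    def func_with_args(*args: Any, **kwargs: Any) -> Any:
        # Scope defaults first; explicitly passed keyword arguments override them.
        merged = dict(_SCOPE_STACK[-1].get(func_with_args, {}))
        merged.update(kwargs)
        return func(*args, **merged)

    func_with_args._is_arg_scoped = True  # marker checked by arg_scope_sketch
    return func_with_args


@contextlib.contextmanager
def arg_scope_sketch(list_ops: Sequence[Callable[..., Any]],
                     **kwargs: Any) -> Iterator[Dict[Any, Dict[str, Any]]]:
    if not isinstance(list_ops, (list, tuple)):
        raise TypeError("list_ops must be a list or tuple.")
    # Copy the enclosing scope so inner scopes extend, not mutate, outer ones.
    new_scope = {op: dict(args) for op, args in _SCOPE_STACK[-1].items()}
    for op in list_ops:
        if not getattr(op, "_is_arg_scoped", False):
            raise ValueError(f"{op} is not decorated with add_arg_scope_sketch.")
        new_scope.setdefault(op, {}).update(kwargs)
    _SCOPE_STACK.append(new_scope)
    try:
        yield new_scope
    finally:
        _SCOPE_STACK.pop()  # restore the enclosing scope on exit
```

In this design explicit keyword arguments always win over scope defaults, and leaving a with block restores whatever defaults were active before it.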
tf.contrib.framework.has_arg_scope(func)
Checks whether a func has been decorated with @add_arg_scope or not.
Args:
func: function to check.
Returns:
a boolean.
tf.contrib.framework.arg_scoped_arguments(func)
Returns the list of kwargs that arg_scope can set for a func.
Args:
func: function which has been decorated with @add_arg_scope.
Returns:
a list of kwarg names.
Variables
tf.contrib.framework.add_model_variable(var)
Adds a variable to the GraphKeys.MODEL_VARIABLES collection.
Args:
var: a variable.
tf.train.assert_global_step(global_step_tensor)
Asserts global_step_tensor is a scalar int Variable or Tensor.
Args:
global_step_tensor: Tensor to test.
tf.contrib.framework.assert_or_get_global_step(graph=None, global_step_tensor=None)
Verifies that a global step tensor is valid or gets one if None is given.
If global_step_tensor is not None, check that it is a valid global step tensor (using assert_global_step). Otherwise find a global step tensor using get_global_step and return it.
Args:
graph: The graph to find the global step tensor for.
global_step_tensor: The tensor to check for suitability as a global step. If None is given (the default), find a global step tensor.
Returns:
A tensor suitable as a global step, or None if none was provided and none was found.
tf.contrib.framework.assign_from_checkpoint(model_path, var_list)
Creates an operation to assign specific variables from a checkpoint.
Args:
model_path: The full path to the model checkpoint. To get the latest checkpoint, use model_path = tf.train.latest_checkpoint(checkpoint_dir).
var_list: A list of Variable objects or a dictionary mapping names in the checkpoint to the corresponding variables to initialize. If empty or None, it would return no_op(), None.
Returns:
the restore_op and the feed_dict that need to be run to restore var_list.
Raises:
ValueError: If the checkpoint specified at model_path is missing one of the variables in var_list.
tf.contrib.framework.assign_from_checkpoint_fn(model_path, var_list, ignore_missing_vars=False, reshape_variables=False)
Returns a function that assigns specific variables from a checkpoint.
Args:
model_path: The full path to the model checkpoint. To get the latest checkpoint, use model_path = tf.train.latest_checkpoint(checkpoint_dir).
var_list: A list of Variable objects or a dictionary mapping names in the checkpoint to the corresponding variables to initialize. If empty or None, it would return no_op(), None.
ignore_missing_vars: Boolean, if True it would ignore variables missing in the checkpoint with a warning instead of failing.
reshape_variables: Boolean, if True it would automatically reshape variables which are of a different shape than the ones stored in the checkpoint but which have the same number of elements.
Returns:
A function that takes a single argument, a tf.Session, that applies the assignment operation.
Raises:
ValueError: If the checkpoint specified at model_path is missing one of the variables in var_list.
tf.contrib.framework.assign_from_values(var_names_to_values)
Creates an assignment operation from a given mapping.
This function provides a mechanism for performing assignment of variables to values in a way that does not fill the graph with large assignment values.
Args:
var_names_to_values: A map from variable names to values.
Returns:
assign_op: An Operation that assigns each of the given variables to the requested values.
feed_dict: The feed dictionary to use when evaluating assign_op.
Raises:
ValueError: if any of the given variable names were not found.
tf.contrib.framework.assign_from_values_fn(var_names_to_values)
Returns a function that assigns specific variables from the given values.
This function provides a mechanism for performing assignment of variables to values in a way that does not fill the graph with large assignment values.
Args:
var_names_to_values: A map from variable names to values.
Returns:
A function that takes a single argument, a tf.Session, that applies the assignment operation.
Raises:
ValueError: if any of the given variable names were not found.
tf.contrib.framework.create_global_step(graph=None)
Create global step tensor in graph.
Args:
graph: The graph in which to create the global step. If missing, use default graph.
Returns:
Global step tensor.
Raises:
ValueError: if global step key is already defined.
tf.train.get_global_step(graph=None)
Get the global step tensor.
The global step tensor must be an integer variable. We first try to find it in the collection GLOBAL_STEP, then by name global_step:0.
Args:
graph: The graph to find the global step in. If missing, use default graph.
Returns:
The global step variable, or None if none was found.
Raises:
TypeError: If the global step tensor has a non-integer type, or if it is not a Variable.
tf.contrib.framework.get_or_create_global_step(graph=None)
Returns the global step variable, creating it if necessary.
Args:
graph: The graph in which to create the global step. If missing, use default graph.
Returns:
the tensor representing the global step variable.
tf.contrib.framework.get_local_variables(scope=None, suffix=None)
Gets the list of local variables, filtered by scope and/or suffix.
Args:
scope: an optional scope for filtering the variables to return.
suffix: an optional suffix for filtering the variables to return.
Returns:
a list of variables in collection with scope and suffix.
tf.contrib.framework.get_model_variables(scope=None, suffix=None)
Gets the list of model variables, filtered by scope and/or suffix.
Args:
scope
: an optional scope for filtering the variables to return.suffix
: an optional suffix for filtering the variables to return.
Returns:
a list of variables in collection with scope and suffix.
tf.contrib.framework.get_unique_variable(var_op_name)
Gets the variable uniquely identified by that var_op_name.
Args:
var_op_name: the full name of the variable op, including the scope.
Returns:
a tensorflow variable.
Raises:
ValueError: if no variable uniquely identified by the name exists.
tf.contrib.framework.get_variables_by_name(given_name, scope=None)
Gets the list of variables that were given that name.
Args:
given_name: name given to the variable without any scope.
scope: an optional scope for filtering the variables to return.
Returns:
a copied list of variables with the given name and scope.
tf.contrib.framework.get_variables_by_suffix(suffix, scope=None)
Gets the list of variables that end with the given suffix.
Args:
suffix: suffix for filtering the variables to return.
scope: an optional scope for filtering the variables to return.
Returns:
a copied list of variables with the given suffix and scope.
tf.contrib.framework.get_variables_to_restore(include=None, exclude=None)
Gets the list of the variables to restore.
Args:
include: an optional list/tuple of scope strings for filtering which variables from the VARIABLES collection to include. If None, all variables are included.
exclude: an optional list/tuple of scope strings for filtering which variables from the VARIABLES collection to exclude. If None, no variables are excluded.
Returns:
a list of variables to restore.
Raises:
TypeError: if include or exclude is provided but is not a list or a tuple.
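The include/exclude filtering can be sketched over variable names alone. This assumes scope filtering is a simple prefix match on the variable name and uses plain strings instead of Variable objects; variables_to_restore is a hypothetical helper, not the library function.

```python
from typing import List, Optional, Sequence


def variables_to_restore(all_names: Sequence[str],
                         include: Optional[Sequence[str]] = None,
                         exclude: Optional[Sequence[str]] = None) -> List[str]:
    for arg_name, arg in (("include", include), ("exclude", exclude)):
        if arg is not None and not isinstance(arg, (list, tuple)):
            raise TypeError(f"{arg_name} is provided but is not a list or a tuple.")
    # include=None keeps every variable; otherwise keep names under any listed scope.
    if include is None:
        included = list(all_names)
    else:
        included = [n for n in all_names if any(n.startswith(s) for s in include)]
    # exclude=None removes nothing.
    excluded = {n for n in all_names if any(n.startswith(s) for s in (exclude or ()))}
    return [n for n in included if n not in excluded]
```

A typical fine-tuning pattern is to restore everything except a freshly initialized output layer, e.g. exclude=["logits"].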
tf.contrib.framework.get_variables(scope=None, suffix=None, collection='variables')
Gets the list of variables, filtered by scope and/or suffix.
Args:
scope: an optional scope for filtering the variables to return.
suffix: an optional suffix for filtering the variables to return.
collection: the collection to search in. Defaults to GraphKeys.VARIABLES.
Returns:
a list of variables in collection with scope and suffix.
tf.contrib.framework.local_variable(initial_value, validate_shape=True, name=None)
Create a variable and add it to the GraphKeys.LOCAL_VARIABLES collection.
Args:
initial_value: See variables.Variable.__init__.
validate_shape: See variables.Variable.__init__.
name: See variables.Variable.__init__.
Returns:
New variable.
tf.contrib.framework.model_variable(*args, **kwargs)
Gets an existing model variable with these parameters or creates a new one.
Args:
name: the name of the new or existing variable.
shape: shape of the new or existing variable.
dtype: type of the new or existing variable (defaults to DT_FLOAT).
initializer: initializer for the variable if one is created.
regularizer: a (Tensor -> Tensor or None) function; the result of applying it on a newly created variable will be added to the collection GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
trainable: If True also add the variable to the graph collection GraphKeys.TRAINABLE_VARIABLES (see tf.Variable).
collections: A list of collection names to which the Variable will be added. Note that the variable is always also added to the GraphKeys.VARIABLES and GraphKeys.MODEL_VARIABLES collections.
caching_device: Optional device string or function describing where the Variable should be cached for reading. Defaults to the Variable's device.
device: Optional device to place the variable. It can be a string or a function that is called to get the device for the variable.
Returns:
The created or existing variable.
tf.contrib.framework.variable(*args, **kwargs)
Gets an existing variable with these parameters or creates a new one.
Args:
name: the name of the new or existing variable.
shape: shape of the new or existing variable.
dtype: type of the new or existing variable (defaults to DT_FLOAT).
initializer: initializer for the variable if one is created.
regularizer: a (Tensor -> Tensor or None) function; the result of applying it on a newly created variable will be added to the collection GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
trainable: If True also add the variable to the graph collection GraphKeys.TRAINABLE_VARIABLES (see tf.Variable).
collections: A list of collection names to which the Variable will be added. If None, it defaults to tf.GraphKeys.VARIABLES.
caching_device: Optional device string or function describing where the Variable should be cached for reading. Defaults to the Variable's device.
device: Optional device to place the variable. It can be a string or a function that is called to get the device for the variable.
Returns:
The created or existing variable.
class tf.contrib.framework.VariableDeviceChooser
Device chooser for variables.
When using parameter servers, it assigns variables to them in a round-robin fashion. When not using a parameter server, it allows GPU or CPU placement.
tf.contrib.framework.VariableDeviceChooser.__call__(op)
tf.contrib.framework.VariableDeviceChooser.__init__(num_tasks=0, job_name='ps', device_type='CPU', device_index=0)
Initialize VariableDeviceChooser.
Usage:
To use with 2 parameter servers: VariableDeviceChooser(2)
To use without parameter servers: VariableDeviceChooser(), or VariableDeviceChooser(device_type='GPU') for GPU placement.
Args:
num_tasks: number of tasks.
job_name: String, a name for the parameter server job.
device_type: Optional device type string (e.g. "CPU" or "GPU").
device_index: int. Optional device index. If left unspecified, device represents 'any' device_index.
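The placement policy described above can be modelled as a callable that hands out device strings. RoundRobinDeviceChooser is a hypothetical class, and the exact device-string format is an assumption modelled on TensorFlow's /job:NAME/task:ID/device:TYPE:INDEX convention; the op argument is accepted but ignored, as in a simple round-robin policy.

```python
class RoundRobinDeviceChooser:
    """Hypothetical stand-in for VariableDeviceChooser's placement logic."""

    def __init__(self, num_tasks: int = 0, job_name: str = "ps",
                 device_type: str = "CPU", device_index: int = 0) -> None:
        self._num_tasks = num_tasks
        self._job_name = job_name
        self._device = f"/device:{device_type}:{device_index}"
        self._next_task_id = 0

    def __call__(self, op: object = None) -> str:
        if self._num_tasks <= 0:
            return self._device  # no parameter servers: local CPU/GPU placement
        task_id = self._next_task_id
        # Advance the counter so the next variable lands on the next task.
        self._next_task_id = (task_id + 1) % self._num_tasks
        return f"/job:{self._job_name}/task:{task_id}{self._device}"
```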
tf.contrib.framework.zero_initializer(ref, use_locking=True, name='zero_initializer')
Initialize ref with all zeros. The ref tensor should be uninitialized; if it is already initialized, you will get a ValueError. This op is intended to save memory during initialization.
Args:
ref: ref of the tensor that needs to be zero-initialized.
name: optional name for this operation.
Returns:
ref, after it has been initialized.
Raises:
ValueError: If the ref tensor is already initialized.