使用GridSearch训练神经网络(RN)时出错:断言失败:[0] [Op:Assert] 名称:EagerVariableNameReuse

python keras scikit-learn data-science tensorflow2.0 匿名 | 2020-08-01 20:15:05


我正在用keras和scikit-learn在一个数据集上训练模型,使用KerasClassifier和GridSearch进行超参数搜索。之前这一切都能正常工作,但在我更新了Nvidia CUDA以及一些需要与之兼容的库之后,它就无法运行了。
我使用的是Nvidia GPU GeForce GTX 960M。
Tensorflow GPU版本:2.1.0
Keras:2.4.0
Scikit Learn:0.23.1
错误如下:
---> 15             grid_result = grid.fit(X_train_normalizado, y_train,verbose=0,callbacks=[es]) #Early Termination
16 if(debug):
17 print("EJECUCIÓN Nº {} COMPLETADA".format(cont))
~\anaconda3\envs\gpu\lib\site-packages\sklearn\utils\validation.py in inner_f(*args, **kwargs)
71 FutureWarning)
72 kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 73 return f(**kwargs)
74 return inner_f
75
~\anaconda3\envs\gpu\lib\site-packages\sklearn\model_selection\_search.py in fit(self, X, y, groups, **fit_params)
763 refit_start_time = time.time()
764 if y is not None:
--> 765 self.best_estimator_.fit(X, y, **fit_params)
766 else:
767 self.best_estimator_.fit(X, **fit_params)
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\wrappers\scikit_learn.py in fit(self, x, y, **kwargs)
221 raise ValueError('Invalid shape for y: ' + str(y.shape))
222 self.n_classes_ = len(self.classes_)
--> 223 return super(KerasClassifier, self).fit(x, y, **kwargs)
224
225 def predict(self, x, **kwargs):
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\wrappers\scikit_learn.py in fit(self, x, y, **kwargs)
155 **self.filter_sk_params(self.build_fn.__call__))
156 else:
--> 157 self.model = self.build_fn(**self.filter_sk_params(self.build_fn))
158
159 if (losses.is_categorical_crossentropy(self.model.loss) and
in create_model3(optimizer, init_mode, dropout_rate, weight_constraint, actNum, neurNum, nCapas)
8
9 # create model
---> 10 model = Sequential()
11 #n_Inicial = len(caracteristicas.columns)
12 n_Inicial = 168 #Cambiar para meterlo como un parámetro del modelo
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\training\tracking\base.py in _method_wrapper(self, *args, **kwargs)
455 self._self_setattr_tracking = False # pylint: disable=protected-access
456 try:
--> 457 result = method(self, *args, **kwargs)
458 finally:
459 self._self_setattr_tracking = previous_value # pylint: disable=protected-access
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\sequential.py in __init__(self, layers, name)
115 # Skip the init in FunctionalModel since model doesn't have input/output yet
116 super(functional.Functional, self).__init__( # pylint: disable=bad-super-call
--> 117 name=name, autocast=False)
118 self.supports_masking = True
119 self._compute_output_and_mask_jointly = True
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\training\tracking\base.py in _method_wrapper(self, *args, **kwargs)
455 self._self_setattr_tracking = False # pylint: disable=protected-access
456 try:
--> 457 result = method(self, *args, **kwargs)
458 finally:
459 self._self_setattr_tracking = previous_value # pylint: disable=protected-access
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py in __init__(self, *args, **kwargs)
306 self._steps_per_execution = None
307
--> 308 self._init_batch_counters()
309 self._base_model_initialized = True
310 _keras_api_gauge.get_cell('model').set(True)
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\training\tracking\base.py in _method_wrapper(self, *args, **kwargs)
455 self._self_setattr_tracking = False # pylint: disable=protected-access
456 try:
--> 457 result = method(self, *args, **kwargs)
458 finally:
459 self._self_setattr_tracking = previous_value # pylint: disable=protected-access
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py in _init_batch_counters(self)
315 # `evaluate`, and `predict`.
316 agg = variables.VariableAggregationV2.ONLY_FIRST_REPLICA
--> 317 self._train_counter = variables.Variable(0, dtype='int64', aggregation=agg)
318 self._test_counter = variables.Variable(0, dtype='int64', aggregation=agg)
319 self._predict_counter = variables.Variable(
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\ops\variables.py in __call__(cls, *args, **kwargs)
260 return cls._variable_v1_call(*args, **kwargs)
261 elif cls is Variable:
--> 262 return cls._variable_v2_call(*args, **kwargs)
263 else:
264 return super(VariableMetaclass, cls).__call__(*args, **kwargs)
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\ops\variables.py in _variable_v2_call(cls, initial_value, trainable, validate_shape, caching_device, name, variable_def, dtype, import_scope, constraint, synchronization, aggregation, shape)
254 synchronization=synchronization,
255 aggregation=aggregation,
--> 256 shape=shape)
257
258 def __call__(cls, *args, **kwargs):
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\ops\variables.py in <lambda>(**kws)
235 shape=None):
236 """Call on Variable class. Useful to force the signature."""
--> 237 previous_getter = lambda **kws: default_variable_creator_v2(None, **kws)
238 for _, getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
239 previous_getter = _make_getter(getter, previous_getter)
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\ops\variable_scope.py in default_variable_creator_v2(next_creator, **kwargs)
2644 synchronization=synchronization,
2645 aggregation=aggregation,
-> 2646 shape=shape)
2647
2648
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\ops\variables.py in __call__(cls, *args, **kwargs)
262 return cls._variable_v2_call(*args, **kwargs)
263 else:
--> 264 return super(VariableMetaclass, cls).__call__(*args, **kwargs)
265
266
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\ops\resource_variable_ops.py in __init__(self, initial_value, trainable, collections, validate_shape, caching_device, name, dtype, variable_def, import_scope, constraint, distribute_strategy, synchronization, aggregation, shape)
1516 aggregation=aggregation,
1517 shape=shape,
-> 1518 distribute_strategy=distribute_strategy)
1519
1520 def _init_from_args(self,
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\ops\resource_variable_ops.py in _init_from_args(self, initial_value, trainable, collections, caching_device, name, dtype, constraint, synchronization, aggregation, distribute_strategy, shape)
1664 shared_name=shared_name,
1665 name=name,
-> 1666 graph_mode=self._in_graph_mode)
1667 # pylint: disable=protected-access
1668 if (self._in_graph_mode and initial_value is not None and
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\ops\resource_variable_ops.py in eager_safe_variable_handle(initial_value, shape, shared_name, name, graph_mode)
241 dtype = initial_value.dtype.base_dtype
242 return _variable_handle_from_shape_and_dtype(
--> 243 shape, dtype, shared_name, name, graph_mode, initial_value)
244
245
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\ops\resource_variable_ops.py in _variable_handle_from_shape_and_dtype(shape, dtype, shared_name, name, graph_mode, initial_value)
173 # support string tensors, we encode the assertion string in the Op name
174 gen_logging_ops._assert( # pylint: disable=protected-access
--> 175 math_ops.logical_not(exists), [exists], name="EagerVariableNameReuse")
176
177 handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\ops\gen_logging_ops.py in _assert(condition, data, summarize, name)
47 return _result
48 except _core._NotOkStatusException as e:
---> 49 _ops.raise_from_not_ok_status(e, name)
50 except _core._FallbackException:
51 pass
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\framework\ops.py in raise_from_not_ok_status(e, name)
6841 message = e.message + (" name: " + name if name is not None else "")
6842 # pylint: disable=protected-access
-> 6843 six.raise_from(core._status_to_exception(e.code, message), None)
6844 # pylint: enable=protected-access
6845
~\anaconda3\envs\gpu\lib\site-packages\six.py in raise_from(value, from_value)
InvalidArgumentError: assertion failed: [0] [Op:Assert] name: EagerVariableNameReuse






0 答案



World is powered by solitude
备案号:湘ICP备19012068号