我已在 TensorFlow 中从头编写了 DETR 对象检测管道。
DETR:Kaggle Notebook 链接:包含全部代码;请复制笔记本以重现问题
我已经测试了管道中的所有独立组件,它们都能正常工作。
但当我开始在我的数据集(以 tf.data.Dataset
形式)上进行训练时,
部分代码链接,该处引发错误
出现以下错误:
ValueError: 在用户代码中:
文件 "/tmp/ipykernel_19/4115406382.py",第 7 行,训练步骤函数中 *
y_pred = matcher(y_train, y_pred)
文件 "/tmp/ipykernel_19/968499204.py",第 64 行,在 __call__ 函数中 *
class_prob, bbox_pred = Matcher.match(class_true, bbox_true, class_prob, bbox_pred)
文件 "/tmp/ipykernel_19/968499204.py",第 53 行,在 match 函数中 *
C = Matcher.batched_cost_matrix(class_true, bbox_true, class_prob, bbox_pred)
文件 "/tmp/ipykernel_19/968499204.py",第 46 行,在 batched_cost_matrix 函数中 *
tf.range(tf.shape(class_true)[0]), fn_output_signature=tf.float32
文件 "/tmp/ipykernel_19/968499204.py",第 22 行,在 compute_cost_matrix 函数中 *
N = tf.shape(class_true)[0]
ValueError: 维度 0 的切片索引 0 超出范围。
节点 '{{node map/while/strided_slice_4}} = StridedSlice[Index=DT_INT32, T=DT_INT32, begin_mask=0, ellipsis_mask=0, end_mask=0, new_axis_mask=0, shrink_axis_mask=1](map/while/Shape, map/while/strided_slice_4/stack, map/while/strided_slice_4/stack_1, map/while/strided_slice_4/stack_2)'
输入形状为:[0], [1], [1], [1],计算得到的输入张量:input[1] = <0>, input[2] = <1>, input[3] = <1>。
当我使用 tf.shape
打印 class_true
的形状时,得到的是 <Tensor("Shape_2:0", shape=(1,), dtype=int32)>
,我不太明白这个结果。
而使用 Tensor.shape
时,返回的形状包含 None
,因此再次出现错误。
但在单独测试(非训练模式)时,我能得到正确的 class_true
形状,即 tf.Tensor([42], shape=(1,), dtype=int32)
。
我应该怎么修复他 ?