Skip to content
Merged
Changes from 1 commit
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
c57a2e7
Merge pull request #3 from tensorflow/master
JimClarke5 Oct 8, 2020
09fc07e
Merge pull request #4 from tensorflow/master
JimClarke5 Oct 27, 2020
a99dcb4
Merge pull request #5 from tensorflow/master
JimClarke5 Nov 17, 2020
ba294ea
Merge pull request #6 from tensorflow/master
JimClarke5 Nov 19, 2020
04f419a
Merge pull request #7 from tensorflow/master
JimClarke5 Dec 30, 2020
ad466ee
Initial checkin
JimClarke5 Jan 1, 2021
092b47d
Initial checkin and sync with master
JimClarke5 Jan 1, 2021
4887b5b
Initial checkin and sync with master
JimClarke5 Jan 1, 2021
04eeea6
JavaDoc cleanup
JimClarke5 Jan 1, 2021
dcb2414
Javadoc fixes
JimClarke5 Jan 3, 2021
82f18bf
Change LossInterface to LossMetric.
JimClarke5 Jan 5, 2021
9aa1511
Removed hashmap for variables, they are not needed as the variables o…
JimClarke5 Jan 7, 2021
1097722
reformat code
JimClarke5 Jan 7, 2021
bc0f468
Add tests for assertBroadcastable
JimClarke5 Jan 11, 2021
41876d5
Change type to resultType
JimClarke5 Jan 11, 2021
61af528
Added V data type for sampleWeights so that it is not forced to be th…
JimClarke5 Jan 11, 2021
c121c07
change 'type' to 'resultType'
JimClarke5 Jan 11, 2021
e9ee98f
clean up mean and fix assert assertBroadcastable
JimClarke5 Jan 11, 2021
9788983
fix error message
JimClarke5 Jan 11, 2021
8857a66
Change sampleWeights to have its own generic type <S extends TNumber>
JimClarke5 Jan 12, 2021
34a779f
Add commment about invalid tests expecting IllegalArgumentExceptions
JimClarke5 Jan 12, 2021
748f16d
Add this exception instead of the more generic IllegalArgumentExcepti…
JimClarke5 Jan 12, 2021
212541b
change IllegalArgumentException to NotBroadcastableException.
JimClarke5 Jan 12, 2021
f0d72d2
reformat code
JimClarke5 Jan 12, 2021
8b49c60
Fis=x Javadoc
JimClarke5 Jan 13, 2021
20c6e98
Fix Reduce to use boradcastWeights,
JimClarke5 Jan 17, 2021
d3d7ee9
Added comment to count to indicate that it may be weighted.
JimClarke5 Jan 17, 2021
fe86b0b
Added SetsOps and fixed AssertBroadcastable to use SetsOps methods,
JimClarke5 Jan 19, 2021
0edd114
Fixed based on various PR comments.
JimClarke5 Jan 19, 2021
7d78fd3
Deleted, no longer needed after change to Variable handling in Metrics.
JimClarke5 Jan 19, 2021
02e7ebf
Merge pull request #8 from tensorflow/master
JimClarke5 Jan 29, 2021
af1b49f
Nicer error messages for mode-forbidden ops (#169)
rnett Jan 2, 2021
7732601
Initialization imprvements (#178)
rnett Jan 7, 2021
a737334
Clairify tensorOf lifetime requirements (#190)
rnett Jan 19, 2021
253cc73
Remove extra generics from op generation (#193)
rnett Jan 26, 2021
22cb5b2
Add Java 11 support - Initial Phase (#185)
JimClarke5 Jan 26, 2021
4d1aa20
Update manual ops for new codegen (#196)
rnett Jan 26, 2021
2b7f6ed
Fix Losses to use CHANNELS_FIRST/LAST for CategoricalCrossentropy
JimClarke5 Jan 20, 2021
3800b71
Fix SetOps to properly convert sparse tensor to dense tensor using tf…
JimClarke5 Jan 30, 2021
3045999
Initial checkin
JimClarke5 Jan 1, 2021
9eb5adf
Initial checkin and sync with master
JimClarke5 Jan 1, 2021
187c17c
Initial checkin and sync with master
JimClarke5 Jan 1, 2021
050fe28
JavaDoc cleanup
JimClarke5 Jan 1, 2021
b640406
Javadoc fixes
JimClarke5 Jan 3, 2021
3715513
Change LossInterface to LossMetric.
JimClarke5 Jan 5, 2021
a1c1976
Removed hashmap for variables, they are not needed as the variables o…
JimClarke5 Jan 7, 2021
6641fca
reformat code
JimClarke5 Jan 7, 2021
fa76043
Add tests for assertBroadcastable
JimClarke5 Jan 11, 2021
e136f4d
Change type to resultType
JimClarke5 Jan 11, 2021
e00f2ef
Added V data type for sampleWeights so that it is not forced to be th…
JimClarke5 Jan 11, 2021
bc6c64b
change 'type' to 'resultType'
JimClarke5 Jan 11, 2021
02da963
clean up mean and fix assert assertBroadcastable
JimClarke5 Jan 11, 2021
44cdc35
fix error message
JimClarke5 Jan 11, 2021
49370b9
Change sampleWeights to have its own generic type <S extends TNumber>
JimClarke5 Jan 12, 2021
24b4125
Add commment about invalid tests expecting IllegalArgumentExceptions
JimClarke5 Jan 12, 2021
43c6b7b
Add this exception instead of the more generic IllegalArgumentExcepti…
JimClarke5 Jan 12, 2021
78e9dab
change IllegalArgumentException to NotBroadcastableException.
JimClarke5 Jan 12, 2021
5508969
reformat code
JimClarke5 Jan 12, 2021
c662524
Fis=x Javadoc
JimClarke5 Jan 13, 2021
512a153
Fix Reduce to use boradcastWeights,
JimClarke5 Jan 17, 2021
0663c3c
Added comment to count to indicate that it may be weighted.
JimClarke5 Jan 17, 2021
122e06b
Added SetsOps and fixed AssertBroadcastable to use SetsOps methods,
JimClarke5 Jan 19, 2021
b7b14b1
Fixed based on various PR comments.
JimClarke5 Jan 19, 2021
13639d3
Deleted, no longer needed after change to Variable handling in Metrics.
JimClarke5 Jan 19, 2021
561322f
Fix Losses to use CHANNELS_FIRST/LAST for CategoricalCrossentropy
JimClarke5 Jan 20, 2021
2a13012
Fix SetOps to properly convert sparse tensor to dense tensor using tf…
JimClarke5 Jan 30, 2021
36f3a69
Merge remote-tracking branch 'upstream/metrics1' into metrics1
JimClarke5 Jan 30, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
clean up mean and fix assert assertBroadcastable
  • Loading branch information
JimClarke5 committed Jan 30, 2021
commit 02da963993fcfd20eb6b2718f67ca468808c4e79
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.tensorflow.types.TBool;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.family.TIntegral;
import org.tensorflow.types.family.TNumber;
import org.tensorflow.types.family.TType;

Expand Down Expand Up @@ -91,7 +92,8 @@ public static <U extends TNumber> Op assertBroadcastable(
}

for (int i = 0; i < valuesRankStatic; i++) {
if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i) && weightsShapeStatic.size(i) != 1) {
if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we want to handle the case where either or both of these sizes are unknown?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not quite sure if it can happen. Whenever I try to create a shape with -1 in a dim, the system seems to fill this in when I fill the object with data. This is true with numpy and TF Java. At any rate to be safe, how about?

Change
if (weightsRankStatic != Shape.UNKNOWN_SIZE && valuesRankStatic != Shape.UNKNOWN_SIZE) {

to

if(!weightsShapeStatic.isUnknown() && !valuesShapeStatic.isUnknown() &&
        !weightsShapeStatic.hasUnknownDimension() & !valuesShapeStatic.hasUnknownDimension()) {

If it fails this, then it falls into dynamic checks

Copy link
Contributor

@deansher deansher Jan 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 done

&& weightsShapeStatic.size(i) != 1) {
throw new IllegalArgumentException(
String.format(
"%s Mismatch at dim %d. values.shape=%s weights.shape=%s.",
Expand Down Expand Up @@ -152,7 +154,7 @@ private static <T extends TNumber> Operand<TBool> hasValidNonscalarShape(
*
* @param tf the TensorFlow Ops
* @param weightsShape the operand for the shape of the sample weights
* @param valuesShape the operand for the shape of the sample weights
* @param valuesShape the operand for the shape of the values
* @param <T> the data type for the operands
* @return a boolean operand to determine if the shapes have valid dimensions or not.
*/
Expand All @@ -163,7 +165,7 @@ private static <T extends TNumber> Operand<TBool> hasValidDims(
Operand<T> validDims =
tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1));
SetDiff1d<T, TInt32> invalidDimsDiff =
tf.setDiff1d(tf.shape.flatten(valuesShape2d), tf.shape.flatten(validDims));
tf.setDiff1d(weightsShape, tf.shape.flatten(validDims));
Operand<T> invalidDims = invalidDimsDiff.out();
Operand<TInt32> numInvalidDims = tf.size(invalidDims);
return tf.math.equal(tf.constant(0), numInvalidDims);
Expand All @@ -178,9 +180,10 @@ private static <T extends TNumber> Operand<TBool> hasValidDims(
* @param tf the TensorFlow Ops
* @param x the Operand used to calculate the mean
* @param <T> the type of the Operand.
* @param <Z> the data type for the result
* @return the mean of the operand
*/
public static <T extends TType> Operand<T> mean(Ops tf, Operand<T> x) {
public static <T extends TType, Z extends TNumber> Operand<Z> mean(Ops tf, Operand<T> x) {
return mean(tf, x, null, false);
}

Expand All @@ -190,58 +193,63 @@ public static <T extends TType> Operand<T> mean(Ops tf, Operand<T> x) {
*
* @param tf the TensorFlow Ops
* @param x the Operand used to calculate the mean
* @param axis Axes to compute the mean.
* @param axes Axes to compute the mean.
* @param <T> the type of the Operand.
* @param <U> the type of the axis.
* @return the mean of the operand, alongside the specified axis.
* @param <U> the type of the axes.
* @param <Z> the data type for the result
* @return the mean of the operand, along the specified axes.
*/
public static <T extends TType, U extends TNumber> Operand<T> mean(
Ops tf, Operand<T> x, Operand<U> axis) {
return mean(tf, x, axis, false);
public static <T extends TType, U extends TIntegral, Z extends TNumber> Operand<Z> mean(
Ops tf, Operand<T> x, Operand<U> axes) {
return mean(tf, x, axes, false);
}

/**
* Calculate the mean of the operand, along all axis.
* Calculate the mean of the operand, along all axes.
*
* @param tf the TensorFlow Ops
* @param x the Operand used to calculate the mean
* @param keepDims Indicates whether to keep the dimensions or not. If <code>keepdims</code> is
* <code>false</code>, the rank of the tensor is reduced by 1 for each entry in <code>axis
* <code>false</code>, the rank of the tensor is reduced by 1 for each entry in <code>axes
* </code>. If <code>keepdims</code> is <code>true</code>, the reduced dimensions are retained
* with length 1.
* @param <T> the type of the operand
* @param <Z> the data type for the result
* @return the mean of elements of <code>x</code>.
*/
public static <T extends TType> Operand<T> mean(Ops tf, Operand<T> x, boolean keepDims) {
public static <T extends TType, Z extends TNumber> Operand<Z> mean(
Ops tf, Operand<T> x, boolean keepDims) {
return mean(tf, x, null, keepDims);
}

/**
* Calculate the mean of the operand, alongside the specified axis.
* Calculate the mean of the operand, alongside the specified axes.
*
* @param tf the TensorFlow Ops
* @param x the Operand used to calculate the mean
* @param axis Axes to compute the mean.
* @param axes Axes to compute the mean.
* @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the
* * rank of the tensor is reduced by 1 for each entry in `axis`. If `keepdims` is `true`, the
* * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the
* * reduced dimensions are retained with length 1.
* @param <T> the data type of the Operand
* @param <U> the data type of the axis
* @param <U> the data type of the axes
* @param <Z> the data type for the result
* @return the mean of elements of `x`.
*/
@SuppressWarnings({"unchecked", "rawtypes"})
public static <T extends TType, U extends TNumber> Operand<T> mean(
Ops tf, Operand<T> x, Operand<U> axis, boolean keepDims) {
public static <T extends TType, U extends TIntegral, Z extends TNumber> Operand<Z> mean(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method has multiple mismatches between compile-time types and runtime types, interwoven with uses of the Z type parameter:

  • The method signature requires the caller to specify Z.
  • The method makes a runtime decision on what runtime type to store in Operand<Z> xf. There's no guarantee that the runtime type will match Z.
  • If x is a TBool, the method tf.dtypes.casts a TFloat32 tensor to a TBool vector (is that supported?) and the returned tensor type is guaranteed to be different from the use of Z in Operand<Z> xf.

Do I understand correctly that this method is intended to return an Operand<T> unless T is TBool, in which case it intends to return an Operand<TFloat32>? If so, then I see no way to express the method's intent in Java's compile-time type system. I see two possibilities:

  • Perhaps implement separate methods for T extends TNumber versus TBool and deal with the difference at the point of call?
  • Or perhaps return an Operand<? extends TNumber>?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I added booleanMean methods taking Operand<TBool> as input, and returning Operand<TFloat64>.

<U extends TIntegral> Operand<TFloat64> booleanMean(
          Ops tf, Operand<TBool> x, Operand<U> axes, boolean keepDims)

I changed all the regular mean operations to and took out the ambiguous .

<T extends TNumber, U extends TIntegral> Operand<T> mean(
      Ops tf, Operand<T> x, Operand<U> axes, boolean keepDims) 

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 done -- I raised one more specific issue on the new code.

Ops tf, Operand<T> x, Operand<U> axes, boolean keepDims) {
// Cannot use generics here because xf may change from TBool to TFloat32
Operand xf;
if (x.asOutput().type() == TBool.class) {
xf = tf.dtypes.cast(x, TFloat32.class);
Operand<Z> xf;
if (x.type().equals(TBool.class)) {
xf = (Operand<Z>) tf.dtypes.cast(x, TFloat32.class);
} else {
xf = x;
xf = (Operand<Z>) x;
}
if (axis == null) {
axis = allAxes(tf, xf);
if (axes == null) {
axes = (Operand<U>) allAxes(tf, xf);
}
return tf.math.mean(xf, axis, Mean.keepDims(keepDims));
Operand theMean = tf.math.mean(xf, axes, Mean.keepDims(keepDims));
return x.type().equals(TBool.class) ? tf.dtypes.cast(theMean, TBool.class) : theMean;
}
}