-
Notifications
You must be signed in to change notification settings - Fork 222
Metrics Phase 1 #180
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Metrics Phase 1 #180
Changes from 1 commit
c57a2e7
09fc07e
a99dcb4
ba294ea
04f419a
ad466ee
092b47d
4887b5b
04eeea6
dcb2414
82f18bf
9aa1511
1097722
bc0f468
41876d5
61af528
c121c07
e9ee98f
9788983
8857a66
34a779f
748f16d
212541b
f0d72d2
8b49c60
20c6e98
d3d7ee9
fe86b0b
0edd114
7d78fd3
02e7ebf
af1b49f
7732601
a737334
253cc73
22cb5b2
4d1aa20
2b7f6ed
3800b71
3045999
9eb5adf
187c17c
050fe28
b640406
3715513
a1c1976
6641fca
fa76043
e136f4d
e00f2ef
bc6c64b
02da963
44cdc35
49370b9
24b4125
43c6b7b
78e9dab
5508969
c662524
512a153
0663c3c
122e06b
b7b14b1
13639d3
561322f
2a13012
36f3a69
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |
| import org.tensorflow.types.TBool; | ||
| import org.tensorflow.types.TFloat32; | ||
| import org.tensorflow.types.TInt32; | ||
| import org.tensorflow.types.family.TIntegral; | ||
| import org.tensorflow.types.family.TNumber; | ||
| import org.tensorflow.types.family.TType; | ||
|
|
||
|
|
@@ -91,7 +92,8 @@ public static <U extends TNumber> Op assertBroadcastable( | |
| } | ||
|
|
||
| for (int i = 0; i < valuesRankStatic; i++) { | ||
| if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i) && weightsShapeStatic.size(i) != 1) { | ||
| if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i) | ||
| && weightsShapeStatic.size(i) != 1) { | ||
| throw new IllegalArgumentException( | ||
| String.format( | ||
| "%s Mismatch at dim %d. values.shape=%s weights.shape=%s.", | ||
|
|
@@ -152,7 +154,7 @@ private static <T extends TNumber> Operand<TBool> hasValidNonscalarShape( | |
| * | ||
| * @param tf the TensorFlow Ops | ||
| * @param weightsShape the operand for the shape of the sample weights | ||
| * @param valuesShape the operand for the shape of the sample weights | ||
| * @param valuesShape the operand for the shape of the values | ||
| * @param <T> the data type for the operands | ||
| * @return a boolean operand to determine if the shapes have valid dimensions or not. | ||
| */ | ||
|
|
@@ -163,7 +165,7 @@ private static <T extends TNumber> Operand<TBool> hasValidDims( | |
| Operand<T> validDims = | ||
| tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); | ||
| SetDiff1d<T, TInt32> invalidDimsDiff = | ||
| tf.setDiff1d(tf.shape.flatten(valuesShape2d), tf.shape.flatten(validDims)); | ||
| tf.setDiff1d(weightsShape, tf.shape.flatten(validDims)); | ||
| Operand<T> invalidDims = invalidDimsDiff.out(); | ||
| Operand<TInt32> numInvalidDims = tf.size(invalidDims); | ||
| return tf.math.equal(tf.constant(0), numInvalidDims); | ||
|
|
@@ -178,9 +180,10 @@ private static <T extends TNumber> Operand<TBool> hasValidDims( | |
| * @param tf the TensorFlow Ops | ||
| * @param x the Operand used to calculate the mean | ||
| * @param <T> the type of the Operand. | ||
| * @param <Z> the data type for the result | ||
| * @return the mean of the operand | ||
| */ | ||
| public static <T extends TType> Operand<T> mean(Ops tf, Operand<T> x) { | ||
| public static <T extends TType, Z extends TNumber> Operand<Z> mean(Ops tf, Operand<T> x) { | ||
| return mean(tf, x, null, false); | ||
| } | ||
|
|
||
|
|
@@ -190,58 +193,63 @@ public static <T extends TType> Operand<T> mean(Ops tf, Operand<T> x) { | |
| * | ||
| * @param tf the TensorFlow Ops | ||
| * @param x the Operand used to calculate the mean | ||
| * @param axis Axes to compute the mean. | ||
| * @param axes Axes to compute the mean. | ||
| * @param <T> the type of the Operand. | ||
| * @param <U> the type of the axis. | ||
| * @return the mean of the operand, alongside the specified axis. | ||
| * @param <U> the type of the axes. | ||
| * @param <Z> the data type for the result | ||
| * @return the mean of the operand, along the specified axes. | ||
| */ | ||
| public static <T extends TType, U extends TNumber> Operand<T> mean( | ||
| Ops tf, Operand<T> x, Operand<U> axis) { | ||
| return mean(tf, x, axis, false); | ||
| public static <T extends TType, U extends TIntegral, Z extends TNumber> Operand<Z> mean( | ||
| Ops tf, Operand<T> x, Operand<U> axes) { | ||
| return mean(tf, x, axes, false); | ||
| } | ||
|
|
||
| /** | ||
| * Calculate the mean of the operand, along all axis. | ||
| * Calculate the mean of the operand, along all axes. | ||
| * | ||
| * @param tf the TensorFlow Ops | ||
| * @param x the Operand used to calculate the mean | ||
| * @param keepDims Indicates whether to keep the dimensions or not. If <code>keepdims</code> is | ||
| * <code>false</code>, the rank of the tensor is reduced by 1 for each entry in <code>axis | ||
| * <code>false</code>, the rank of the tensor is reduced by 1 for each entry in <code>axes | ||
| * </code>. If <code>keepdims</code> is <code>true</code>, the reduced dimensions are retained | ||
| * with length 1. | ||
| * @param <T> the type of the operand | ||
| * @param <Z> the data type for the result | ||
| * @return the mean of elements of <code>x</code>. | ||
| */ | ||
| public static <T extends TType> Operand<T> mean(Ops tf, Operand<T> x, boolean keepDims) { | ||
| public static <T extends TType, Z extends TNumber> Operand<Z> mean( | ||
| Ops tf, Operand<T> x, boolean keepDims) { | ||
| return mean(tf, x, null, keepDims); | ||
| } | ||
|
|
||
| /** | ||
| * Calculate the mean of the operand, alongside the specified axis. | ||
| * Calculate the mean of the operand, alongside the specified axes. | ||
| * | ||
| * @param tf the TensorFlow Ops | ||
| * @param x the Operand used to calculate the mean | ||
| * @param axis Axes to compute the mean. | ||
| * @param axes Axes to compute the mean. | ||
| * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the | ||
| * * rank of the tensor is reduced by 1 for each entry in `axis`. If `keepdims` is `true`, the | ||
| * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the | ||
| * * reduced dimensions are retained with length 1. | ||
| * @param <T> the data type of the Operand | ||
| * @param <U> the data type of the axis | ||
| * @param <U> the data type of the axes | ||
| * @param <Z> the data type for the result | ||
| * @return the mean of elements of `x`. | ||
| */ | ||
| @SuppressWarnings({"unchecked", "rawtypes"}) | ||
| public static <T extends TType, U extends TNumber> Operand<T> mean( | ||
| Ops tf, Operand<T> x, Operand<U> axis, boolean keepDims) { | ||
| public static <T extends TType, U extends TIntegral, Z extends TNumber> Operand<Z> mean( | ||
|
||
| Ops tf, Operand<T> x, Operand<U> axes, boolean keepDims) { | ||
| // Cannot use generics here because xf may change from TBool to TFloat32 | ||
| Operand xf; | ||
| if (x.asOutput().type() == TBool.class) { | ||
| xf = tf.dtypes.cast(x, TFloat32.class); | ||
| Operand<Z> xf; | ||
| if (x.type().equals(TBool.class)) { | ||
| xf = (Operand<Z>) tf.dtypes.cast(x, TFloat32.class); | ||
| } else { | ||
| xf = x; | ||
| xf = (Operand<Z>) x; | ||
| } | ||
| if (axis == null) { | ||
| axis = allAxes(tf, xf); | ||
| if (axes == null) { | ||
| axes = (Operand<U>) allAxes(tf, xf); | ||
| } | ||
| return tf.math.mean(xf, axis, Mean.keepDims(keepDims)); | ||
| Operand theMean = tf.math.mean(xf, axes, Mean.keepDims(keepDims)); | ||
| return x.type().equals(TBool.class) ? tf.dtypes.cast(theMean, TBool.class) : theMean; | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How do we want to handle the case where either or both of these sizes are unknown?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not quite sure if it can happen. Whenever I try to create a shape with -1 in a dim, the system seems to fill this in when I fill the object with data. This is true with numpy and TF Java. At any rate to be safe, how about?
Change
if (weightsRankStatic != Shape.UNKNOWN_SIZE && valuesRankStatic != Shape.UNKNOWN_SIZE) {to
If it fails this, then it falls into dynamic checks
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍 done