COMPUTER VISION

Data augmentation on TPU

Mark Porath
23 min read · May 12, 2021
…vast pipelines where untold dimensions of color space lie in wait…

Model training is an iterative process, but it doesn’t have to be slow. With a tf.data pipeline, TFRecords → TPU, work that lasted hours on CPU is done in no time flat. That’s why Google made TFRecords and TPUs.

Yet challenges remain. TensorFlow and Keras have limited options for data augmentation. Most of us turn to specialized libraries like scikit-image, OpenCV or Albumentations for anything more than a random flip.

Are these libraries even compatible with a TFRecords dataset? Or with a TPU? We intend to find out.

Table of Contents
-- The challenge: a smarter, leafier crop
-- A color mask
* cv_color_mask
* tf_color_mask 👈
* demo: disappearing racoon
-- A custom crop
* loop_crop
* array_crop
* max_rand_crop 🎲
-- Pipeline integration
* trouble
* tf.py_function 🔑
-- Pipeline comparisons
* on CPU, GPU
* on TPU
* recap 🏁
-- Tensorboard
-- Portability
* graph of tf_color_mask 💯
* graphs of crop functions
-- Conclusion
-- Epilogue: color space brainstorm
* a side of hue
* immunohistochemistry in CIELAB 🔬
* pervasive problem with stains
* fluoro-combo-blender
-- References

The Challenge

For an exemplary pipeline, check out Kaggle’s starter notebook in its popular contest, Cassava Leaf Disease Classification. Attached to the notebook is a set of TFRecords, TensorFlow’s recommended source for an efficient data stream, especially with big, unstructured data.

This blue sky and puffy clouds are considered a healthy leaf?

The cassava jpegs from 2020 are big — 600x800x3 — and are framed every which way: extreme close-ups, distant views, upward angles of sky, downward angles of shadow and dirt. Foliage might fill the frame or just skirt along the sides. If the goal is to classify the leaves as healthy or not, a resize or center crop might leave something to be desired.

Researchers at UCLA and around the globe have practice segmenting digital leafy content. According to their studies, the requisite color space is not RGB, but HSV (hue-saturation-value). Across camera types and lighting conditions, a hue threshold can segment plant material consistently.

OpenCV is extremely popular for this sort of task. No doubt you have seen demonstrations of the invisibility cloak? Most of those projects use OpenCV to convert an image from RGB to HSV, then mask the cloak’s hue.

Similarly, we will convert from RGB to HSV, but only momentarily, in order to evaluate coordinates for a leafy crop. With those coordinates, we will crop the RGB tensor to remove extraneous info before the model’s input layer.

Ignore non-leafy hues to find the right spot. Now that’s a healthy leaf!

Of course, if we crop the images offline, before writing to TFRecords, OpenCV will operate independently of the tensor pipeline. But that approach really limits our options. We certainly can’t create and upload new TFRecords files between epochs. Thus, any preprocessing function that could augment the training dataset in some randomized or dynamic way must take effect in the pipeline.

Our plan is to retrofit the pipeline from within decode_image, rather than later in data_augment. In other words, we intend to impact training and validation sets, prepping individual tensor images as they emerge from our GCS bucket, before any batching. How to use our functions at inference — even beyond a Kaggle kernel — is a discussion we save for the end.

Color mask: OpenCV vs TensorFlow

Our first function relies on cv2 for the color filtering. It starts and ends with a TensorFlow op, tf.cast, to handle tensors in and out. The default values for HSV arguments bracket mostly green hues.

The tensor passes to cv2 as a numpy array.

Two adjustments distinguish our tensor-friendly function from the typical color segmentation snippet from OpenCV. First, our cv_color_mask starts from a tensor. Because OpenCV accepts tensors only as numpy arrays, we pass the tensor to cv2 ops with .numpy() in lines 4 and 6.

Second, our masking routine starts RGB, whereas most cv2 example code starts BGR. The difference comes down to cv2.imread versus tf.image.decode_jpeg. So we use cv2.COLOR_RGB2HSV in line 4. Those two changes, numpy and RGB, ready the function for tensor input.

For comparison, let’s define a similar color mask function but without OpenCV. The module tf.image offers functions for color space conversion, and low-level APIs like tf.math cover the rest. Here’s our tf_color_mask:

100% tensor-friendly 👍

You may have noticed that the default values for lower_hsv and upper_hsv differ across functions. In particular, the hue thresholds are 20…90 in the cv function, but 28…128 in the tf version. Why the different values for the same leafy hue?

In TensorFlow, the HSV channels are floats with ranges 0 to 1. Integers are convenient, so our custom function accepts integer values 0 to 255 for lower_hsv and upper_hsv then converts and scales down for us.
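The scaling just described can be sketched as follows. This is a minimal illustration under stated assumptions, not the function from the article: the name tf_color_mask_sketch and the saturation and value defaults are hypothetical, and only the hue bounds 28 and 128 come from the text.

```python
import tensorflow as tf

def tf_color_mask_sketch(image,
                         lower_hsv=(28, 40, 40),      # hue from the article; S and V are assumed
                         upper_hsv=(128, 255, 255)):  # hue from the article; S and V are assumed
    """Zero out pixels whose HSV coordinates fall outside [lower_hsv, upper_hsv].

    image: uint8 RGB tensor of shape [height, width, 3].
    Bounds are integers on a 0..255 scale and are divided by 255 to match
    the 0..1 floats returned by tf.image.rgb_to_hsv.
    """
    rgb = tf.cast(image, tf.float32) / 255.0
    hsv = tf.image.rgb_to_hsv(rgb)
    lower = tf.constant(lower_hsv, tf.float32) / 255.0
    upper = tf.constant(upper_hsv, tf.float32) / 255.0
    # A pixel is kept only if all three channels sit inside the bracket.
    in_range = tf.reduce_all(
        tf.logical_and(hsv >= lower, hsv <= upper), axis=-1, keepdims=True)
    return image * tf.cast(in_range, image.dtype)

# One green pixel and one red pixel: only the green one should survive.
demo = tf.constant([[[0, 200, 0], [200, 0, 0]]], dtype=tf.uint8)
masked_demo = tf_color_mask_sketch(demo).numpy()
print(masked_demo)
```

Every op here is TensorFlow-native, which is the property that matters later for graph mode and TPU.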

In OpenCV, Saturation and Value are integers with ranges 0 to 255, while Hue has a different range, 0 to 179.

HSV Cylinder (wikipedia)

It is not uncommon for Hue coordinates to top out at 180 or 360. The HSV color space is represented by a cylinder, with Hue designated by degrees of arc. OpenCV’s scale for Hue easily maps to degrees. That factor of (255/179) explains the superficial difference between Hue defaults in our custom color-mask functions: 20*(255/179) ≈ 28 and 90*(255/179) ≈ 128.

Let’s see how our mask functions compare. No need to pipe the cassava quite yet; we can import our furry friend from SciPy for this quick check.
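The same arithmetic as a couple of throwaway helpers. The function names are hypothetical; the 255/179 factor is the one quoted above, and the doubling reflects OpenCV storing 8-bit hue as degrees divided by 2.

```python
def cv_hue_to_255(h_cv):
    """Rescale an OpenCV hue (0..179) to the 0..255 integer scale."""
    return round(h_cv * 255 / 179)

def cv_hue_to_degrees(h_cv):
    """OpenCV's 8-bit hue is degrees of arc divided by 2."""
    return h_cv * 2

for h in (20, 90):
    print(h, cv_hue_to_255(h), cv_hue_to_degrees(h))
# 20 28 40
# 90 128 180
```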

Functional equivalence. The better mask will be determined by execution time and pipeline compatibility.

So far, so good. But our goal is a smart crop. Imagine if the far-left image were center-cropped, then evaluated for a leaf disease. Not so smart. The leaf classifier would see little more than racoon.

Incidentally, the green channel of RGB isn’t very discerning. Not without our mask. The white and bright racoon parts dominate every channel of RGB. So our color mask certainly helps. It pushes those bright but irrelevant racoon parts to the void.

Now we can probe the masked image for patches of leafy content — e.g., by counting the number of non-zero pixels or summing pixel values. Where the masked image has a lot to say, we will crop the original image.

Custom crop: Loops vs Tensor Arrays

Let’s write the crop function a few ways for comparison. Tensorboard will declare a winner for us in the end. Again, the driving question is this: What kind of functions can we insert without impeding tf.data?

One way to find a suitably green area to crop is by looping over sample patches of an image. It’s hard to imagine a loop as an efficient computational graph, but its performance might surprise.

Consistent with TensorFlow guidelines, this loops over tf.range, not the python iterable range. Green pixel values are added with tf.reduce_sum rather than np.sum. Note that our function actually returns the coordinates used to make the crop; tf.image.crop_to_bounding_box will execute the crop.
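For readers without the notebook open, here is a rough sketch of that pattern. It is an illustration under assumptions rather than the article's loop_crop: the name loop_crop_sketch, the 3x3 grid of offsets, and the use of the green channel sum as the score are choices made for this example.

```python
import numpy as np
import tensorflow as tf

@tf.function
def loop_crop_sketch(masked, target_size=(300, 300)):
    """Return (y, x) of the greenest patch among a 3x3 grid of candidates."""
    th, tw = target_size
    shape = tf.shape(masked)
    max_y = shape[0] - th
    max_x = shape[1] - tw
    best_sum = tf.constant(-1.0)
    best_y = tf.constant(0)
    best_x = tf.constant(0)
    # tf.range (not Python range) lets AutoGraph stage these as graph loops.
    for i in tf.range(3):
        for j in tf.range(3):
            y = i * max_y // 2
            x = j * max_x // 2
            patch = masked[y:y + th, x:x + tw, 1]
            green = tf.reduce_sum(tf.cast(patch, tf.float32))
            if green > best_sum:
                best_sum, best_y, best_x = green, y, x
    return best_y, best_x

# Toy input: a 6x6 black image with a green 2x2 block in the bottom-right corner.
toy = np.zeros((6, 6, 3), np.uint8)
toy[4:, 4:, 1] = 255
toy = tf.constant(toy)
loop_y, loop_x = loop_crop_sketch(toy, (2, 2))
loop_patch = tf.image.crop_to_bounding_box(toy, loop_y, loop_x, 2, 2)
print(int(loop_y), int(loop_x))
```

In the pipeline, the coordinates would be computed on the masked image and then applied to the original RGB tensor.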

Schematic of cv_color_mask then loop_crop, target_size=[300,300]. Dotted boxes are for demo, only.

We could do even more with TensorFlow’s Python API — e.g., we could initialize max_green, y1 and x1 as tf.Variables outside the function definition, then use the .assign method inside the function to update their values. Let’s see whether TensorFlow’s AutoGraph can handle those details.

AutoGraph does more than meet us halfway. It traverses the toughest terrain. To appreciate its role, compare the docs for tf.reduce_sum and tf.while_loop. With which are you more confident? 😄 Be thankful to handle tf.reduce_sum and let AutoGraph reconceptualize loop_crop as a multitude of conditionals in a computational graph.

Let’s meet the next contestant, a much different implementation. It obviates the loop and variables with a tensor array.

The selling point of a tensor array is that it isn’t a tensor constant, which is immutable, or a tensor variable, which is initialized only at trace and might require coordination across devices.

A tensor array is actually an array of tensors, and its methods .write and .read are simple enough. Then argmax allows us to compare operations across multiple tensors without the need for a loop.

Both these crop functions — loop_crop and array_crop — are deterministic. They compare a central patch with 4 side and 4 corner patches.
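A tensor-array version of the same comparison might look like the sketch below. Again, this is a guess at the shape of the idea, not the article's array_crop: the function name and the 3x3 grid of offsets are assumptions, and the Python range loop is simply unrolled at trace time so no graph-level loop remains.

```python
import numpy as np
import tensorflow as tf

@tf.function
def array_crop_sketch(masked, target_size=(224, 224)):
    """Score nine fixed patches, store the scores in a TensorArray, pick the argmax."""
    th, tw = target_size
    shape = tf.shape(masked)
    max_y = shape[0] - th
    max_x = shape[1] - tw
    # Corners, sides and center: a 3x3 grid of top-left offsets.
    offsets = [(i * max_y // 2, j * max_x // 2) for i in range(3) for j in range(3)]
    sums = tf.TensorArray(tf.float32, size=9)
    for k in range(9):  # plain Python loop, unrolled when the function is traced
        y, x = offsets[k]
        patch = masked[y:y + th, x:x + tw, 1]
        sums = sums.write(k, tf.reduce_sum(tf.cast(patch, tf.float32)))
    best = tf.argmax(sums.stack(), output_type=tf.int32)
    ys = tf.stack([o[0] for o in offsets])
    xs = tf.stack([o[1] for o in offsets])
    return ys[best], xs[best]

# Same toy input as before: green 2x2 block in the bottom-right corner.
toy_img = np.zeros((6, 6, 3), np.uint8)
toy_img[4:, 4:, 1] = 255
arr_y, arr_x = array_crop_sketch(tf.constant(toy_img), (2, 2))
print(int(arr_y), int(arr_x))
```

Note that .write returns a new TensorArray handle, so the result has to be reassigned each time.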

Schematic of tf_color_mask then array_crop, target_size=[224,224]. Red boxes are for demo, only.

Of course, the advantage of a random crop is that it isn’t deterministic. It might miss the best part of the frame this epoch, but it changes every epoch. Preprocessing meets augmentation — all good. So let’s invite one more contestant to step forward.

The following function, max_rand_crop, combines a loop and a tensor array. It also features a new argument, saccades. In the previous two functions, 9 patches were compared for green content. If we want more or fewer random patches compared, we could use max_rand_crop and set saccades to a number bigger or smaller than 9.

Yes, the function name is an oxymoron. It makes sense, though, right? From randomly sampled patches, maximize the green: max_rand_crop. We’re open to suggestions 💭

Concise and flexible. Good for augmentation at training and preprocessing at inference. Fewer saccades could mean faster execution and more variety across epochs. More saccades could ensure the ideal crop, at a cost. True to our goal, max_rand_crop is a smarter crop, when applied to a color-masked tensor input.

In max_rand_crop, patch could be defined more simply as a tensor slice. We used tf.roll, instead — honestly, just for fun. 😃 In other contexts, tf.roll could set up for a crop that wraps around the image boundaries in a way a slice could not. You might keep roll in mind for your next augmentation idea.

Schematic of tf_color_mask then max_rand_crop, target_size=[224,224], saccades=6. Dotted boxes are for demo, only.

A more significant piece of max_rand_crop is the random number generator, rng. TensorFlow now discourages use of tf.random.uniform. Its current recommendation is to initialize a random number generator — technically a variable — outside the function definition.

Depending on the distribution strategy, placement of the rng could get complicated. For example, if we planned to use a multi-core CPU or GPU, we might copy or split the random number generator. We won’t cross that bridge: Kaggle’s CPU and GPU are both 1 core, as of now, and its TPU is so fast that we prefer to let it be with our one rng.
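To make the rng placement concrete, here is a minimal sketch of a random variant. It is not the article's max_rand_crop: the name, the seed, and the plain tensor slice (in place of tf.roll) are assumptions for illustration. The generator is created once, outside the function, and only drawn from inside.

```python
import numpy as np
import tensorflow as tf

# Created once, outside any tf.function, as TensorFlow's RNG guide recommends.
rng = tf.random.Generator.from_seed(42)  # the seed value is arbitrary

@tf.function
def max_rand_crop_sketch(masked, target_size=(224, 224), saccades=9):
    """Sample `saccades` random patches and return (y, x) of the greenest one."""
    th, tw = target_size
    shape = tf.shape(masked)
    max_y = shape[0] - th
    max_x = shape[1] - tw
    # Integer draws are in [minval, maxval), hence the + 1.
    ys = rng.uniform([saccades], 0, max_y + 1, dtype=tf.int32)
    xs = rng.uniform([saccades], 0, max_x + 1, dtype=tf.int32)
    sums = tf.TensorArray(tf.float32, size=saccades)
    for k in tf.range(saccades):
        patch = masked[ys[k]:ys[k] + th, xs[k]:xs[k] + tw, 1]
        sums = sums.write(k, tf.reduce_sum(tf.cast(patch, tf.float32)))
    best = tf.argmax(sums.stack(), output_type=tf.int32)
    return ys[best], xs[best]

all_green = np.zeros((6, 6, 3), np.uint8)
all_green[..., 1] = 255
all_green = tf.constant(all_green)
draws = [max_rand_crop_sketch(all_green, (2, 2), 4) for _ in range(5)]
full_y, full_x = max_rand_crop_sketch(all_green, (6, 6), 3)
print([(int(y), int(x)) for y, x in draws])
```

When the target size equals the image size, the only valid offset is (0, 0), which makes a handy sanity check.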

Pipeline Integration

Our preprocessing functions are called from decode_image. With 2 mask functions and 3 ways to crop, we have 6 possible combos to clock. And we’d like to compare performance across CPU, GPU and TPU. 😕 And we’d be cheating ourselves if we didn’t try decorating all the candidate functions with @tf.function with and without input signatures to avoid retracing and whatnot. 😵 Let’s see how far we get.

Sidebar to troubleshoot cv_color_mask

There is a hitch. With cv_color_mask in the pipeline, our training is over before it begins. We can troubleshoot this, though. If you are familiar with py_function, jump to the next subsection.

Didn’t we confirm that cv_color_mask accepts a tensor input? It worked on the racoon tensor, right? Why the error message only now?

We checked cv_color_mask on the racoon in eager mode, not in the context of the TFRecords pipeline. When we call get_training_dataset(), the tf.data pipeline siphons everything to graph mode. More precisely, tf.data.Dataset.map traces read_tfrecord— which calls decode_image and cv_color_mask— and attempts to stage all their component ops as computational graphs.

From https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map:

Note that irrespective of the context in which [the mapped function] is defined (eager vs. graph), tf.data traces the function and executes it as a graph. To use Python code inside of the function you have a few options:

1) Rely on AutoGraph to convert Python code into an equivalent graph computation. The downside of this approach is that AutoGraph can convert some but not all Python code.

2) Use tf.py_function, which allows you to write arbitrary Python code but will generally result in worse performance than 1).

3) Use tf.numpy_function, which also allows you to write arbitrary Python code. Note that tf.py_function accepts tf.Tensor whereas tf.numpy_function accepts numpy arrays and returns only numpy arrays.

AutoGraph handles a lot for us. Nevertheless, OpenCV and other imported Python libraries are beyond the purview of AutoGraph.

The error message, 'Tensor' object has no attribute 'numpy', is not super clear about the underlying issue. It points to the numpy attribute as a hindrance, when we included that bit to overcome compatibility issues. Nowhere in the error message is a Tensor contrasted with an EagerTensor, or is Dataset.map described as the staging catalyst.

You search Stack Overflow, and the answers miss the mark. Uneasy feeling, right? Seriously, Stack Overflow is a phenomenal resource. But in this case, most answers pertain to earlier versions of TensorFlow, when eager mode was not the default. So you read about sessions and whatnot. 😑 Deep breath.

For a background on eager vs graph mode, or imperative vs declarative programming, the Eager Team’s synopsis from February 2019 is educational. Of course, that article was released before the latest updates to TensorFlow. Still, it clarifies what the updates would entail and why. After reading that article, even the outdated Stack Overflow posts will make more sense.

Back on track with tf.py_function

One of many improvements in TF2 is a more dexterous py_function. Even though cv_color_mask includes both cv2 and tf-native ops, the improved py_function sorts that out, and it’s compatible with GPU. 👍

To use cv_color_mask in the pipeline, we need the wrapper, py_function.
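The wrapping pattern looks roughly like this. To keep the sketch free of an OpenCV dependency, a NumPy-only stand-in (numpy_only_mask, with a made-up "green exceeds red" rule) takes the place of cv_color_mask; the wrapper name and the toy dataset are likewise hypothetical.

```python
import numpy as np
import tensorflow as tf

def numpy_only_mask(image):
    """Stand-in for a cv2-based mask. Runs eagerly inside tf.py_function,
    so the incoming tensor is an EagerTensor and .numpy() is available."""
    arr = image.numpy()
    keep = arr[..., 1] > arr[..., 0]  # hypothetical rule: green exceeds red
    return (arr * keep[..., None]).astype(np.uint8)

def decode_with_mask(image):
    """Graph-friendly wrapper: Dataset.map traces this, py_function defers the rest."""
    masked = tf.py_function(func=numpy_only_mask, inp=[image], Tout=tf.uint8)
    # py_function loses static shape information, so restore it for downstream ops.
    masked.set_shape(image.shape)
    return masked

images = np.zeros((2, 4, 4, 3), np.uint8)
images[0] = [10, 200, 0]   # greenish image: kept
images[1] = [200, 10, 0]   # reddish image: zeroed
ds = tf.data.Dataset.from_tensor_slices(images).map(decode_with_mask)
results = [m.numpy() for m in ds]
print(results[0][0, 0], results[1][0, 0])
```

Calling numpy_only_mask directly inside the mapped function, without the wrapper, would reproduce the 'numpy' attribute error discussed above.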

Every image has its shape

Notice that decode_image includes arguments for height and width. TensorFlow is picky about dimensions when it stages functions for graph mode. Why complicate things with options to vary shape?

Well, we couldn’t resist adding the 2019 cassava jpegs to the mix. 😛 Those images happen to come in all shapes and sizes. So we created a set of TFRecords that include most of the 2019 and 2020 images, along with image dimensions. Then we adjusted read_tfrecords, decode_image, and the preprocessing functions accordingly. We wrote the TFRecords like this and read them like this.

While these TFRecords are far from perfect, they contain over 20,000 unique images balanced across the 5 classes — 4 disease and 1 healthy. More importantly, by specifying every image’s dimensions, we can avoid any systematic skew of aspect ratio.

Data leakage is subtle and might not pertain to our cassava sets or aspect ratios; regardless, handling images individually seems a worthwhile pursuit. So, yeah, all original dimensions will come through read_tfrecords. Then our custom crop will return a single size for batching and model input.

Pipeline Comparisons

The training notebook used for our comparisons is here. The classifier to be trained is a bare-bones convolutional neural network. It is built within a strategy scope; of course, get_strategy() returns just the _DefaultDistributionStrategy on Kaggle’s CPU and GPU kernels, both of which have 1 core.

Every pipeline configuration was run three times, 5 epochs per run. Training times (ms/step) were converted to speeds (examples/sec) and averaged across epochs. For all runs on CPU and GPU, IMAGE_SIZE=[224,224] and BATCH_SIZE=16.
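The conversion from step time to throughput is simple division. The step times below are hypothetical placeholders, not measurements from our runs.

```python
def examples_per_sec(ms_per_step, batch_size):
    """Convert a Keras-style step time (ms/step) into throughput (examples/sec)."""
    return batch_size / (ms_per_step / 1000.0)

# Hypothetical per-epoch step times for one 5-epoch run with BATCH_SIZE=16.
epoch_ms_per_step = [40.0, 32.0, 32.0, 32.0, 32.0]
speeds = [examples_per_sec(ms, 16) for ms in epoch_ms_per_step]
mean_speed = sum(speeds) / len(speeds)
print(round(mean_speed, 1))
```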

In the notebook, switching configurations is as simple as changing two lines in decode_image (see gif below).

2 line changes → 7 preprocessing routines
  • On CPU, cv_color_mask is faster than tf_color_mask and practically as fast as the baseline condition, which skips the color space conversion and custom crop altogether. (OpenCV is known to be performant: Albumentations chooses cv2.COLOR_RGB2HSV for its HueSaturationValue op under the hood.)
  • On GPU, tf_color_mask takes the lead, but cv_color_mask is definitely serviceable; in the upcoming section on Tensorboard, we clarify that tf_color_mask did not accelerate until explicitly placed on GPU.

Beyond comparison: TPU

Flip the switch on Kaggle’s TPU v3-8, and the same notebook used for training on CPU and GPU is practically unchanged. TensorFlow opts for TPUStrategy. With 8 cores available, we increase from BATCH_SIZE=16 to BATCH_SIZE=128. Experts at Kaggle would remind us to adjust the learning rate by a factor of 8, as well, but we aren’t concerned with that parameter today.

We re-ran training with the various combinations of color mask and crop. Again, 15 epochs per condition: 5 epochs in each of 3 runs. So, how fast can the model consume images, with our preprocessing involved?

This bar graph raises a couple questions. First, why no data for cv_color_mask?

Alas, getting cv_color_mask through the TPU pipeline is not feasible — not via the set-up from Kaggle, anyway. Many have tried.

Lucky for us, tf_color_mask serves the same purpose and can be serialized for placement on TPU (or the CPU part of a TPU). The result: model training, with color-space conversion for preprocessing, done in the blink of an eye. 💫

Another question: If this TPU is so fast, why do the bars for tf_color_mask look so puny?

Nothing would look impressive next to this baseline condition, which bypasses the custom preprocessing with a simple center cut. In that baseline condition, the model devours images at warp speed: over 50,000 images 224x224x3 — gradient descent and all — in less than a minute. It’s incredible.

A follow-up question: Why is there such a large gap between that baseline condition and the tf_color_mask routine on TPU, when no such disparity occurred on CPU or GPU?

The CPU and GPU were both 1 core, prompting the default distribution strategy. It seems the TPU distribution strategy affords an opportunity for optimization? Seems complicated. Why get bogged down in the details? The glass is half full!

Check out the table for a better look at the wide disparity between tf_color_mask on TPU and all other smart-crop conditions. We collapsed across crop functions to get these color-mask averages; every number in the table summarizes 45 epochs in 9 runs.

Let’s recap:

  • cv_color_mask relies on py_function and cannot be serialized to TPU.
  • tf_color_mask, coupled with any of our custom crops, soars on TPU. Just flip the switch for faster training by almost an order of magnitude. And the memory of a TPU could accommodate much bigger images.

Behind the scenes with Tensorboard

Whereas fiddling with the TPU was not required, fiddling with the GPU was. Initially, Kaggle’s GPU accelerated cv_color_mask conditions as expected, but had very little impact on configurations involving tf_color_mask. A Tensorboard callback on model.fit was revealing…

Of course, with a bigger model and more gradients to compute, the GPU would be busier. That might leave our CPU enough time to preprocess and prepare each batch without delay. In the case of our relatively small CNN, the CPU was holding up the GPU’s work here, clearly.

By design, the tf.data pipeline leaves preprocessing to the CPU. So it’s no surprise that the tf_color_mask conditions were held up. But what happened with cv_color_mask and the GPU here?

It appears that the py_function default is to exploit the available GPU for any compatible ops it wraps, and this seems to override the tf.data norm.

In any event, we take the hint and insert with tf.device("/device:GPU:0"): in line 2 of decode_image. For the bar graph presented earlier, GPU training speeds were collected with explicit device placement, as indicated in the gif below.

Explicit placement on GPU made all the difference, especially for tf_color_mask.

The new line in decode_image did not have a noticeable impact on training speeds with cv_color_mask, but completely changed the game for tf_color_mask. Tensorboard pies tell the story, too.

The big take-aways:

  • py_function used GPU for TensorFlow ops automatically, hence eager execution with cv_color_mask was faster than advertised.
  • The Tensorboard callback on model.fit clarified a great deal.
  • Effort and time spent investigating the GPU is another reason to appreciate the easy speed of tf_color_mask on TPU.

Portability

Our smart crop could impact training and inference, so portability matters.

For the same reasons that cv_color_mask is incompatible with the TPU pipeline — namely, its need for py_function — it cannot be exported via the SavedModel class. The imported library, OpenCV, ties cv_color_mask to its environment.

Can we confirm that tf_color_mask is potentially portable? Tensorboard can, yes. Earlier we profiled entire steps of training. We can also profile an individual function call, even outside the pipeline. In order to force eager TensorFlow to generate the graph, we first call tf_color_mask = tf.function(tf_color_mask) or decorate the function definition with @tf.function.
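Outside of Tensorboard, a quick way to peek at a staged graph is to list its op types. The sketch below uses a small hypothetical function, hue_channel, rather than tf_color_mask itself; the check for EagerPyFunc relies on that being the op type tf.py_function leaves behind in a graph.

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, None, 3], tf.uint8)])
def hue_channel(image):
    """Tiny TF-native function: RGB uint8 in, hue channel (0..1 floats) out."""
    rgb = tf.cast(image, tf.float32) / 255.0
    return tf.image.rgb_to_hsv(rgb)[..., 0]

# With an input signature, the function can be traced without example data.
concrete = hue_channel.get_concrete_function()
op_types = sorted({op.type for op in concrete.graph.get_operations()})
uses_python = "EagerPyFunc" in op_types
print(op_types)
print("depends on Python:", uses_python)
```

A graph made only of native ops, with no Python callback in the list, is the kind that can travel with a SavedModel.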

tf_color_mask 👆

TensorFlow is successful in generating this Python-independent graph of tf_color_mask, signifying that the function can be exported via the SavedModel class. The green indicates that tf_color_mask is also 100% TPU-compatible.

Incidentally, we can profile any function in Tensorboard, including cv_color_mask. However, to compel staging as a graph, we need tf.function, which conflicts with py_function. There will be no graph of cv_color_mask.

For a really deep dive, profile cv_color_mask with and then without its wrapper, py_function; or profile tf_color_mask with and then without its decorator, @tf.function; or profile a function on CPU, then again explicitly placed on GPU. Transparency with Tensorboard. It’s like an ant farm.

loop_crop

All three crop functions are potentially portable when decorated with tf.function and are 100% TPU-friendly (as we know from our training runs, as well).

The graph of loop_crop looks simpler than those of array_crop and max_rand_crop. Looks can be deceiving.

While viewing the same graphs in Tensorboard, you can double-click to expand nodes and see detail. The reason loop_crop appears simple here is that its sprawling auxiliary parts— while this and while that — are collapsed.

Regardless of looks, the crop functions worked in the tf.data pipeline at similar speeds. 👌 Whatever works.

array_crop 👆
max_rand_crop (saccades=6)

Conclusion

Most TF2 tutorials cover model layers and training loops — the big matmul stuff, by the batch. We focused on preprocessing individual images, instead.

What we learned is encouraging. The tf.data pipeline is amenable to your custom code, even when you rely on imported libraries, …to a point.

In order to reap the full benefits of model training on TPU — not to mention deployment — use tensor-friendly ops. Kaggle and Tensorboard can facilitate your trial-and-error.

Epilogue: Find a Place for Color Space in Neural Networks

Proof of principle requires example. Our example happened to hinge on a color space conversion and hue.

The color space conversion served several purposes, in fact. It helped us eliminate prejudicial context from cassava images. It augmented the dataset, with non-deterministic smart sampling.

We hope it reminded capable coders that the mind-blowing pipeline, TFRecords →TPU, accommodates simple creativity, too.

Did we mention that a $30,000 TPU from Kaggle is yours free 30 hours/week? Excited? Need a challenge? Find treasure in a color space beyond RGB.

A side of hue

In the human brain, some visual processing is color-sensitive, some is not. If our goal is to organize a neural network like primate visual cortex, we might convert RGB tensors to grayscale for one afferent stream and HSV for another; the color images would be processed more slowly, but catch up via skip connections (Chen et al., 2007).

But how much do these real-life details matter? ML engineers could start small, e.g. attaching hue summary stats to RGB or grayscale images as supplemental features.

Let’s say your plan involves transfer learning from a base model trained on RGB images. You need another RGB dataset, right? But hue could be extracted in the pipeline — via momentary conversion to HSV — and somehow summarized as a vector.

The model’s lower layers take your RGB images and return an RGB feature vector, per the usual. That’s when Keras concatenates the extracted RGB features with your hue summary vector. The concatenated result is sent through the final dense layers. An RGB entrée with a side of hue.
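One possible hue summary is a normalized histogram. This is only a brainstorm-level sketch: the function name and the choice of 16 bins are assumptions, and the resulting vector is what a Keras Concatenate layer could join with the base model's features.

```python
import tensorflow as tf

def hue_histogram(image, nbins=16):
    """Summarize an RGB uint8 image as a normalized histogram of its hue channel."""
    rgb = tf.cast(image, tf.float32) / 255.0
    hue = tf.image.rgb_to_hsv(rgb)[..., 0]  # floats in 0..1
    counts = tf.histogram_fixed_width(hue, [0.0, 1.0], nbins=nbins)
    counts = tf.cast(counts, tf.float32)
    return counts / tf.reduce_sum(counts)

# A uniformly green image: all of the mass should land in a single bin.
green = tf.constant([[[0, 200, 0]] * 4] * 4, dtype=tf.uint8)
hue_vec = hue_histogram(green)
print(hue_vec.numpy())
```

A real version would probably ignore low-saturation pixels, whose hue is not meaningful.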

Immunohistochemistry in CIELAB color space

Nowhere is color space exploration more warranted than under the microscope, where the subject is stains.

An article from Geread et al. (2019) reminds us that HSV is not the only color space besides RGB. They found that, “…the b∗ channel of the L∗a∗b∗ color space… can be used to automatically separate blue and brown colors for effective DAB and H [hematoxylin] stain separation.”

CIELAB color space (wikipedia)

It is common to treat tissue samples with multiple stains, e.g. to analyze overlap or co-localization. But then methods to quantify those stains must tell them apart. Apparently, the CIELAB color space (abbreviated L*a*b*) proved useful in that regard. In fact, Geread et al. report that their algorithm, or “unsupervised color separation framework” — largely based on histograms and skew — outperformed six ML classifiers.

Machine learning is no failure with immunological stains, though. Quite the contrary. Deep learning has proved impressive with all sorts of microscopy. Unfortunately, neural networks in this domain are hampered by a pervasive, unresolved problem…

The pervasive problem with stains

No, it’s not about separating stains on a slide. It’s about reconciling different images of identical stains, even identical slides. Bigger datasets would be great for model training, and researchers are willing to share. But somehow their images look incompatible.

The aforementioned article from Geread et al. is part of a special section in the journal Frontiers in Bioengineering and Biotechnology, where most of the 12 articles address this same pervasive problem: color normalization. The most readable of the bunch was probably by Pontalba et al.

A brief Medium article from Prof. Marc Aubreville clarifies the issue really well, too:

If a solution requires a change to how images are captured in the first place, that is doable. Durand et al. (2018) implemented machine learning at the image acquisition stage.

Their goal was not to normalize images from different labs, but to acquire optimal images on a particular scanner and day. Still, it would be interesting to compare output from different labs that adopt the same trained model for image acquisition.

All of these articles are accessible for free. All share data. Most even share their models and trained weights. Heroic. If computer vision and research don’t light your fire, remember that this is medicine, too. All hands on deck.

Fluoro-combo-blender, yeah!

Fluorochromes are carefully engineered to be distinguishable by hue. Obvious, right? Let’s see how this defining characteristic could be exploited in not-so-obvious ways. (This area of microscopy is loaded with tools and techniques; for thorough background, look here or here.)

Imagine that the challenge is cell-type classification. Cell types are characterized by patterns of protein expression, and several proteins are tagged with glowing bits.

The jpegs look like your classic Lite-Brite toy: against a black background, specks of color in glowing green, cyan, yellow, red, etc. As many as 8 distinct colors might target various proteins in a single cell (though smaller palettes are more common).

Images are acquired via the same microscope used to visualize the prepared slides. Each fluorophore is excited (illuminated) and captured separately, and those still-shots are flattened or merged into a multi-colored jpeg.

Let’s assume that your TFRecords include the final, color-merged images. What if you undid that work in the pipeline, separating colors by hue? Easy-peasy: use tf_color_mask.

By author. These images were acquired separately at the microscope, but public datasets might include only the color-merge (right). In the pipeline, tf_color_mask could recreate the RFP (left) and GFP (center) images.

For every color or combination of colors, build a model. Train the separate models on variously masked portions of the HSV cylinder. Even a single model, or circumscribed set of fluorophore tags, might lead to accurate labeling of cell types.

Build one more model — the blender — to learn which intersections of color matter most, in which cases and how much. It sounds abstract, but model stacking is a common practice. It works well, as it should.

Model ensembles are a recent manifestation of a mainstay design principle. Most of the earliest neural networks derived from the same premise: parallel distributed processing. Recall that our primate visual cortex does it that way, too, with dorsal and ventral streams.

We can stack more color-combo models on a TPU than in our brain. So why not give it a try?

If stacking works well with masked fluorophores, other contexts come to mind. Thermal images. Satellite images.

Explore the space. It is easy to accept fluorescent microscopy and satellite imagery as the payload, and rest our mind’s eye. But look again — just beyond conversion there is adventure yet in those jpegs. There are vast pipelines where untold dimensions of color space lie in wait. When you get there, stay sharp. Run your ops on TPU and rally at the edge.

References, Resources & Links

Kaggle datasets

  • Cassava images from 2019 and 2020.
  • Custom TFRecords with image height and width as example features; this dataset is also attached to the article’s companion notebook.

Kaggle notebooks

Leaf expertise

OpenCV tutorials

TensorFlow background

Model deployment

Epilogue: Challenges in Color Space

  • Chi-Ming Chen, Peter Lakatos, Ankoor S. Shah, Ashesh D. Mehta, Syndee J. Givre, Daniel C. Javitt, Charles E. Schroeder. Functional Anatomy and Interaction of Fast and Slow Visual Pathways in Macaque Monkeys, Cerebral Cortex, Volume 17, Issue 7, July 2007, Pages 1561–1569, https://doi.org/10.1093/cercor/bhl067 Few labs have the tools to monitor extracellular field potentials in awake behaving monkeys. Fewer still can monitor such signals from multiple cortical layers, capturing laminar profiles with depth electrodes. When a lab can simultaneously record laminar profiles from multiple, distant areas of cortex while probing a cognizant macaque’s visual fields — fascinating stuff!
  • Ishikawa-Ankerhold HC, Ankerhold R, Drummen GPC. Advanced Fluorescence Microscopy Techniques — FRAP, FLIP, FLAP, FRET and FLIM. Molecules 2012, 17, 4047–4132. https://doi.org/10.3390/molecules17044047 → 86 pages of outstanding physics, pics, historical narrative and more.
  • Chudakov DM, Matz MV, Lukyanov S, Lukyanov KA. Fluorescent proteins and their applications in imaging living cells and tissues. Physiol Rev. 2010 Jul;90(3):1103–63. doi: 10.1152/physrev.00038.2009
  • https://link.medium.com/rN1rKRVqafb Different whole-slide scanners yield images that are too different.
  • Geread RS, Morreale P, Dony RD, Brouwer E, Wood GA, Androutsos D, Khademi A (2019). IHC Color Histograms for Unsupervised Ki67 Proliferation Index Calculation. Front. Bioeng. Biotechnol. 7:226. doi: 10.3389/fbioe.2019.00226 Images can solve real problems, but RGB is not always the ideal color space.
  • Pontalba JT, Gwynne-Timothy T, David E, Jakate K, Androutsos D, Khademi A (2019). Assessing the Impact of Color Normalization in Convolutional Neural Network-Based Nuclei Segmentation Frameworks. Front. Bioeng. Biotechnol. 7:300. doi: 10.3389/fbioe.2019.00300
  • Otálora S, Atzori M, Andrearczyk V, Khan A and Müller H (2019). Staining Invariant Features for Improving Generalization of Deep Convolutional Neural Networks in Computational Pathology. Front. Bioeng. Biotechnol. 7:198. doi: 10.3389/fbioe.2019.00198
  • Durand, A., Wiesner, T., Gardner, MA. et al. A machine learning approach for online automated optimization of super-resolution optical microscopy. Nat Commun 9, 5247 (2018). https://doi.org/10.1038/s41467-018-07668-y
  • Rumelhart, D. E., McClelland, J. L. & the PDP Research Group. Parallel Distributed Processing: Explorations in the Microstructure of Cognition. Volume 1: Foundations (MIT Press, Cambridge, Massachusetts, 1986). https://mitpress.mit.edu/books/parallel-distributed-processing-volume-1 → The brain and mind sure work well together.