Deep neural networks (DNNs) provide more accurate results as the size and coverage of their training data increase. While investing in high-quality and large-scale labeled datasets is one path to model improvement, another is leveraging prior knowledge, concisely referred to as “rules”: reasoning heuristics, equations, associative logic, or constraints. Consider a common example from physics where a model is given the task of predicting the next state in a double pendulum system. While the model may learn to estimate the total energy of the system at a given point in time from empirical data alone, it will frequently overestimate the energy unless it is also provided an equation that reflects the known physical constraints, e.g., energy conservation. The model fails to capture such well-established physical rules on its own. How could one effectively teach such rules so that DNNs absorb the relevant knowledge beyond simply learning from the data?
In “Controlling Neural Networks with Rule Representations”, published at NeurIPS 2021, we present Deep Neural Networks with Controllable Rule Representations (DeepCTRL), an approach for providing rules to a model that is agnostic to data type and model architecture and can be applied to any kind of rule defined for inputs and outputs. The key advantage of DeepCTRL is that it does not require retraining to adapt the rule strength. At inference, the user can adjust the rule strength based on the desired operating point between accuracy and rule adherence. We also propose a novel input perturbation method, which helps generalize DeepCTRL to non-differentiable constraints. In real-world domains where incorporating rules is critical, such as physics and healthcare, we demonstrate the effectiveness of DeepCTRL in teaching rules for deep learning. DeepCTRL ensures that models follow rules more closely while also providing accuracy gains at downstream tasks, thus improving reliability and user trust in the trained models. Additionally, DeepCTRL enables novel use cases, such as hypothesis testing of the rules on data samples and unsupervised adaptation based on shared rules between datasets.
The benefits of learning from rules are multifaceted:
Learning Jointly from Rules and Tasks The conventional approach incorporates rules by including them in the calculation of the loss. There are three limitations of this approach that we aim to address: (i) rule strength needs to be defined before learning (thus the trained model cannot operate flexibly based on how much the data satisfies the rule); (ii) rule strength is not adaptable to target data at inference if there is any mismatch with the training setup; and (iii) the rule-based objective needs to be differentiable with respect to learnable parameters (to enable learning from labeled data).
DeepCTRL modifies canonical training by creating rule representations, coupled with data representations, which is the key to enabling the rule strength to be controlled at inference time. During training, these representations are stochastically concatenated with a control parameter, indicated by α, into a single representation. The strength of the rule on the output decision can be increased by increasing the value of α. By modifying α at inference, users can control the behavior of the model to adapt to unseen data.
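The coupling of the two representations can be sketched roughly as follows. This is a minimal illustration, not the published implementation: the function names are hypothetical, and the specific choices (scaling the rule and data representations by α and 1 − α before concatenation, weighting the two losses the same way, and drawing α from a Beta(0.1, 0.1) distribution during training) are assumptions made for the example.

```python
import random


def combine_representations(z_rule, z_data, alpha):
    """Scale the rule and data representations by alpha and (1 - alpha), then concatenate."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    return [alpha * v for v in z_rule] + [(1.0 - alpha) * v for v in z_data]


def combined_loss(rule_loss, task_loss, alpha):
    """Weight the rule-based and task-based losses with the same alpha (assumed form)."""
    return alpha * rule_loss + (1.0 - alpha) * task_loss


def sample_alpha(rng):
    """Draw a training-time alpha. Beta(0.1, 0.1) is an assumption; it favors values near 0 and 1."""
    return rng.betavariate(0.1, 0.1)


# Training time: alpha is random. Inference time: the user picks alpha directly.
rng = random.Random(0)
z_train = combine_representations([1.0, 2.0], [3.0, 4.0], sample_alpha(rng))
z_infer = combine_representations([1.0, 2.0], [3.0, 4.0], 0.25)
```

Because the downstream decision layers see many values of α during training, the same trained model can be queried with any α in [0, 1] afterwards.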
Integrating Rules via Input Perturbations Training with rule-based objectives requires the objectives to be differentiable with respect to the learnable parameters of the model. However, many valuable rules are non-differentiable with respect to the input. For example, “blood pressure higher than 140 is likely to lead to cardiovascular disease” is a rule that is hard to combine with conventional DNNs. We introduce a novel input perturbation method to generalize DeepCTRL to non-differentiable constraints by applying small perturbations (random noise) to input features and constructing a rule-based constraint based on whether the outcome changes in the desired direction.
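One way to picture the perturbation idea is a penalty that is positive only when nudging an input upward moves the prediction in the wrong direction. The sketch below is an illustration under stated assumptions: `risk_model` is a hypothetical stand-in for a trained network, the perturbation is a fixed positive offset rather than random noise, and the hinge form of the penalty is one plausible choice rather than the method's confirmed definition.

```python
import math


def risk_model(systolic_bp, weight=0.05, bias=-7.0):
    """Hypothetical stand-in for a trained network: a logistic risk score."""
    return 1.0 / (1.0 + math.exp(-(weight * systolic_bp + bias)))


def perturbation_rule_penalty(model, x, delta):
    """Penalize the model when raising the input by delta lowers the predicted risk.

    Encodes the rule "higher blood pressure implies higher risk" without
    differentiating the rule itself with respect to the input.
    """
    return max(0.0, model(x) - model(x + delta))


# A model whose risk rises with blood pressure satisfies the rule (zero penalty).
ok_penalty = perturbation_rule_penalty(risk_model, 140.0, 5.0)


# A model whose risk falls with blood pressure violates it (positive penalty).
def inverted_model(bp):
    return risk_model(bp, weight=-0.05, bias=7.0)


bad_penalty = perturbation_rule_penalty(inverted_model, 140.0, 5.0)
```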
Use Cases We evaluate DeepCTRL on machine learning use cases from physics and healthcare, where utilization of rules is particularly important.
We quantify reliability of a model with the verification ratio, which is the fraction of output samples that satisfy the rules. Operating at a better verification ratio could be beneficial, especially if the rules are known to be always valid, as in natural sciences. By adjusting the control parameter α, a higher rule verification ratio, and thus more reliable predictions, can be achieved.
To demonstrate this, we consider the time-series data generated from double pendulum dynamics with friction from a given initial state. We define the task as predicting the next state of the double pendulum from the current state while imposing the rule of energy conservation. To quantify how much the rule is learned, we evaluate the verification ratio.
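The verification ratio itself is simple to compute. The sketch below assumes each model output is summarized as a pair of (current energy, predicted next energy) and that the rule is checked as "predicted energy does not exceed current energy", which is a plausible reading for a system with friction; both the data layout and the exact rule check are assumptions for illustration.

```python
def verification_ratio(outputs, satisfies_rule):
    """Fraction of outputs for which the rule holds."""
    if not outputs:
        raise ValueError("outputs must be non-empty")
    return sum(1 for out in outputs if satisfies_rule(out)) / len(outputs)


def energy_does_not_increase(pair):
    """Assumed rule check: predicted next energy <= current energy."""
    current_energy, predicted_energy = pair
    return predicted_energy <= current_energy


# Hypothetical (current, predicted) energy pairs; the second one violates the rule.
pairs = [(1.0, 0.9), (1.0, 1.1), (0.5, 0.5), (2.0, 1.5)]
ratio = verification_ratio(pairs, energy_does_not_increase)
```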
We compare the performance of DeepCTRL on this task to conventional baselines trained with a fixed rule-based constraint added to the objective as a regularization term weighted by a coefficient λ. The highest of these regularization coefficients provides the highest verification ratio (shown by the green line in the second graph below); however, its prediction error is slightly worse than that of λ = 0.1 (orange line). We find that the lowest prediction error of the fixed baseline is comparable to that of DeepCTRL, but the highest verification ratio of the fixed baseline is still lower, which implies that DeepCTRL could provide accurate predictions while following the law of energy conservation. In addition, we consider the benchmark of imposing the rule constraint with the Lagrangian Dual Framework (LDF) and demonstrate two results where its hyperparameters are chosen by the lowest mean absolute error (LDF-MAE) and the highest rule verification ratio (LDF-Ratio) on the validation set. The performance of the LDF method is highly sensitive to what the main constraint is, and its output is not reliable (black and pink dashed lines).
Additionally, the figures above illustrate the advantage DeepCTRL has over conventional approaches. For example, increasing the rule strength λ from 0.1 to 1.0 improves the verification ratio (from 0.7 to 0.9), but does not improve the mean absolute error. Arbitrarily increasing λ will continue to drive the verification ratio closer to 1, but will result in worse accuracy. Thus, finding the optimal value of λ requires many training runs of the baseline model, whereas DeepCTRL can find the optimal value of the control parameter α much more quickly, because α is adjusted at inference without retraining.
The strengths of some rules may differ between subsets of the data. For example, in disease prediction, the correlation between cardiovascular disease and higher blood pressure is stronger for older patients than younger patients. In such situations, when the task is shared but data distribution and the validity of the rule differ between datasets, DeepCTRL can adapt to the distribution shifts by controlling α.
Exploring this example, we focus on the task of predicting whether cardiovascular disease is present or not using a cardiovascular disease dataset. Given that higher systolic blood pressure is known to be strongly associated with cardiovascular disease, we consider the rule: “higher risk if the systolic blood pressure is higher”. Based on this, we split the patients into two groups: (1) unusual, where a patient has high blood pressure but no disease, or lower blood pressure but has disease; and (2) usual, where a patient has high blood pressure and disease, or low blood pressure and no disease.
We demonstrate below that the source data do not always follow the rule, and thus the effect of incorporating the rule can depend on the source data. The test cross entropy, which indicates classification accuracy (lower cross entropy is better), is plotted below against rule strength for source and target datasets with varying usual / unusual ratios. The error monotonically increases as α → 1 because the enforcement of the imposed rule, which doesn’t accurately reflect the source data, becomes more strict.
When a trained model is transferred to the target domain, the error can be reduced by controlling α. To demonstrate this, we show three domain-specific datasets, which we call Target 1, 2, and 3. In Target 1, where the majority of patients are from the usual group, as α is increased, the rule-based representation has more weight and the resultant error decreases monotonically.
When the ratio of usual patients is decreased in Target 2 and 3, the optimal α is an intermediate value between 0 and 1. These results demonstrate the capability to adapt the trained model via α.
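Since α is an inference-time input, adapting to a target dataset can be pictured as a plain sweep over candidate values with no retraining. The sketch below is illustrative only: `toy_predict` is a hypothetical stand-in for a trained model (it blends a rule-driven probability with an uninformative data-driven one), and selecting α by lowest cross entropy on labeled target samples is an assumption about how the sweep might be scored.

```python
import math


def cross_entropy(probs, labels):
    """Mean binary cross entropy, with probabilities clipped away from 0 and 1."""
    eps = 1e-12
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / len(labels)


def best_alpha(predict, inputs, labels, alphas):
    """Return the (alpha, loss) pair with the lowest cross entropy; predict(x, alpha) gives P(disease)."""
    scored = [(cross_entropy([predict(x, a) for x in inputs], labels), a) for a in alphas]
    loss, alpha = min(scored)
    return alpha, loss


def toy_predict(systolic_bp, alpha):
    """Hypothetical model: alpha blends a rule-driven probability with a flat data-driven one."""
    rule_prob = 0.9 if systolic_bp >= 140 else 0.1
    data_prob = 0.5
    return alpha * rule_prob + (1.0 - alpha) * data_prob


bp = [150, 120, 160, 110]
usual_labels = [1, 0, 1, 0]    # every patient follows the rule
unusual_labels = [0, 1, 0, 1]  # every patient contradicts the rule
alpha_usual, _ = best_alpha(toy_predict, bp, usual_labels, [0.0, 0.5, 1.0])
alpha_unusual, _ = best_alpha(toy_predict, bp, unusual_labels, [0.0, 0.5, 1.0])
```

In this toy setup the sweep picks the largest α when the target follows the rule and the smallest α when it does not, mirroring the qualitative behavior described for the target datasets.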
Conclusions Learning from rules can be crucial for constructing interpretable, robust, and reliable DNNs. We propose DeepCTRL, a new methodology used to incorporate rules into data-learned DNNs. DeepCTRL enables controllability of rule strength at inference without retraining. We propose a novel perturbation-based rule encoding method to integrate arbitrary rules into meaningful representations. We demonstrate three use cases of DeepCTRL: improving reliability given known principles, examining candidate rules, and domain adaptation using the rule strength.
Acknowledgements We greatly appreciate the contributions of Jinsung Yoon, Xiang Zhang, Kihyuk Sohn and Tomas Pfister.
Deep machine learning (ML) systems have achieved considerable success in medical image analysis in recent years. One major contributing factor is access to abundant labeled datasets, which are used to train highly effective supervised deep learning models. However, in the real world, these models may encounter samples exhibiting rare conditions that are individually too infrequent for per-condition classification. Nevertheless, such conditions can be collectively common because they follow a long-tail distribution and, when taken together, can represent a significant portion of cases. For example, in a recent deep learning dermatological study, hundreds of rare conditions made up around 20% of cases encountered by the model at test time.
To prevent models from generating erroneous outputs on rare samples at test time, there remains a considerable need for deep learning systems with the ability to recognize when a sample is not a condition it can identify. Detecting previously unseen conditions can be thought of as an out-of-distribution (OOD) detection task. By successfully identifying OOD samples, preventive measures can be taken, like abstaining from prediction or deferring to a human expert.
Traditional computer vision OOD detection benchmarks work to detect dataset distribution shifts. For example, a model may be trained on CIFAR images but be presented with street view house numbers (SVHN) as OOD samples, two datasets with very different semantic meanings. Other benchmarks seek to detect slight differences in semantic information, e.g., between images of a truck and a pickup truck, or two different skin conditions. The semantic distribution shifts in such near-OOD detection problems are more subtle in comparison to dataset distribution shifts, and thus, are harder to detect.
In “Does Your Dermatology Classifier Know What it Doesn’t Know? Detecting the Long-Tail of Unseen Conditions”, published in Medical Image Analysis, we tackle this near-OOD detection task in the application of dermatology image classification. We propose a novel hierarchical outlier detection (HOD) loss, which leverages existing fine-grained labels of rare conditions from the long tail and modifies the loss function to group unseen conditions and improve identification of these near OOD categories. Coupled with various representation learning methods and the diverse ensemble strategy, this approach enables us to achieve better performance for detecting OOD inputs.
The Near-OOD Dermatology Dataset We curated a near-OOD dermatology dataset that includes 26 inlier conditions, each of which is represented by at least 100 samples, and 199 rare conditions considered to be outliers. Outlier conditions can have as few as one sample per condition. The separation criteria between inlier and outlier conditions can be specified by the user. Here the cutoff sample size between inlier and outlier was 100, consistent with our previous study. The outliers are further split into training, validation, and test sets that are intentionally mutually exclusive to mimic real-world scenarios, where rare conditions shown during test time may not have been seen in training.
Hierarchical Outlier Detection Loss We propose to use “known outlier” samples during training that are leveraged to aid detection of “unknown outlier” samples during test time. Our novel hierarchical outlier detection (HOD) loss performs a fine-grained classification of individual classes for all inlier or outlier classes and, in parallel, a coarse-grained binary classification of inliers vs. outliers in a hierarchical setup (see the figure below). Our experiments confirmed that HOD is more effective than performing a coarse-grained classification followed by a fine-grained classification, as this could result in a bottleneck that impacted the performance of the fine-grained classifier.
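The parallel fine-grained and coarse-grained objectives can be sketched for a single sample as follows. This is a simplified illustration, not the published loss: it assumes the coarse outlier probability is obtained by summing the fine-grained probabilities of the known outlier classes, that the two cross-entropy terms are added with a weight (`coarse_weight`, a hypothetical parameter), and that inlier classes occupy the first `num_inliers` indices.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def hod_loss(logits, label, num_inliers, coarse_weight=1.0):
    """Fine-grained cross entropy plus weighted coarse inlier-vs-outlier cross entropy.

    Classes 0..num_inliers-1 are inliers; the remaining classes are known outliers.
    """
    eps = 1e-12
    probs = softmax(logits)
    fine = -math.log(max(probs[label], eps))
    p_outlier = sum(probs[num_inliers:])
    p_coarse = p_outlier if label >= num_inliers else 1.0 - p_outlier
    coarse = -math.log(max(p_coarse, eps))
    return fine + coarse_weight * coarse


# Two inlier classes and two known outlier classes; the true label is inlier class 0.
uniform_loss = hod_loss([0.0, 0.0, 0.0, 0.0], label=0, num_inliers=2)
confident_loss = hod_loss([5.0, 0.0, 0.0, 0.0], label=0, num_inliers=2)
```

Both terms are computed from the same softmax output, so the fine-grained classifier is not forced through a separate coarse decision first.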
We use the sum of the predictive probabilities of the outlier classes as the OOD score. As the primary OOD detection metric we use the area under the receiver operating characteristic curve (AUROC), which ranges between 0 and 1 and measures the separability between inliers and outliers. A perfect OOD detector, which separates all inliers from outliers, is assigned an AUROC score of 1. A popular baseline method, called reject bucket, separates each inlier individually from the outliers, which are grouped into a single dedicated abstention class. In addition to a fine-grained classification of each individual inlier and outlier class, the HOD loss–based approach separates the inliers collectively from the outliers with a coarse-grained prediction loss, resulting in better generalization. Although the two methods are similar, we demonstrate that our HOD loss–based approach outperforms other baseline methods that leverage outlier data during training, achieving an AUROC score of 79.4% on the benchmark, a significant improvement over that of reject bucket, which achieves 75.6%.
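The OOD score and the AUROC metric described here can be computed in a few lines. The sketch below uses the pairwise-ranking definition of AUROC (the probability that a randomly chosen outlier receives a higher score than a randomly chosen inlier, with ties counted as one half); the class layout, with inliers first, and the example scores are assumptions for illustration.

```python
def ood_score(probs, num_inliers):
    """Sum of the predicted probabilities assigned to the outlier classes."""
    return sum(probs[num_inliers:])


def auroc(scores, is_outlier):
    """Probability that a random outlier outscores a random inlier (ties count 0.5)."""
    outlier_scores = [s for s, o in zip(scores, is_outlier) if o]
    inlier_scores = [s for s, o in zip(scores, is_outlier) if not o]
    if not outlier_scores or not inlier_scores:
        raise ValueError("need at least one inlier and one outlier")
    wins = sum(
        1.0 if out > inl else 0.5 if out == inl else 0.0
        for out in outlier_scores
        for inl in inlier_scores
    )
    return wins / (len(outlier_scores) * len(inlier_scores))


# Hypothetical softmax output over two inlier and two outlier classes.
score = ood_score([0.5, 0.25, 0.125, 0.125], num_inliers=2)

# Hypothetical OOD scores for two outliers followed by two inliers.
perfect = auroc([0.9, 0.8, 0.2, 0.1], [True, True, False, False])
partial = auroc([0.9, 0.3, 0.4, 0.1], [True, True, False, False])
```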
Representation Learning and the Diverse Ensemble Strategy We also investigate how different types of representation learning help in OOD detection in conjunction with HOD by pretraining on ImageNet, BiT-L, SimCLR and MICLe models. We observe that including HOD loss improves OOD performance compared to the reject bucket baseline method for all four representation learning methods.
Another orthogonal approach for improving OOD detection performance and accuracy is the deep ensemble, which aggregates outputs from multiple independently trained models to provide a final prediction. We build upon the deep ensemble, but instead of using a fixed architecture with a fixed pre-training, we combine different representation learning architectures (ImageNet, BiT-L, SimCLR and MICLe) and different objective loss functions (HOD and reject bucket). We call this a diverse ensemble strategy, which we demonstrate outperforms the deep ensemble for OOD performance and inlier accuracy.
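As a rough picture of the aggregation step, the sketch below averages the class probabilities predicted by several ensemble members. Plain averaging is an assumption here, since the text only says that outputs are aggregated, and the member outputs are made-up numbers.

```python
def ensemble_average(member_probs):
    """Average class probabilities across ensemble members (one list per member)."""
    if not member_probs:
        raise ValueError("need at least one ensemble member")
    count = len(member_probs)
    return [sum(column) / count for column in zip(*member_probs)]


# Hypothetical outputs from two members that differ in pre-training and loss.
averaged = ensemble_average([[0.5, 0.5], [1.0, 0.0]])
```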
Downstream Clinical Trust Analysis While we mainly focus on improving the performance for OOD detection, the ultimate goal for our dermatology model is to have high accuracy in predicting inlier and outlier conditions. We go beyond traditional performance metrics and introduce a “penalty” matrix that jointly evaluates inlier and outlier predictions for model trust analysis to approximate downstream impact. For a fixed confidence threshold, we count the following types of mistakes: (i) incorrect inlier predictions (i.e., mistaking inlier condition A as inlier condition B); (ii) incorrect abstention of inliers (i.e., abstaining from making a prediction for an inlier); and (iii) incorrect prediction for outliers as one of the inlier classes.
To account for the asymmetrical consequences of the different types of mistakes, penalties can be 0, 0.5, or 1. Both incorrect inlier and outlier-as-inlier predictions can potentially erode user trust in the model and were penalized with a score of 1. Incorrect abstention of an inlier as an outlier was penalized with a score of 0.5, indicating that potential model users should seek additional guidance given the model-expressed uncertainty or abstention. For correct decisions no cost is incurred, indicated by a score of 0.
Penalty scores: 0 (correct), 0.5 (incorrect abstention of an inlier), 1 (incorrect, mistakes that may erode trust).
Because real-world scenarios are more complex and contain a variety of unknown variables, the numbers used here represent simplifications to enable qualitative approximations for the downstream impact on user trust of outlier detection models, which we refer to as “cost”. We use the penalty matrix to estimate a downstream cost on the test set and compare our method against the baseline, thereby making a stronger case for its effectiveness in real-world scenarios. As shown in the plot below, our proposed solution incurs a much lower estimated cost in comparison to the baseline over all possible operating points.
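The penalty scheme translates directly into a small scoring function. In the sketch below, the label names, the `ABSTAIN` sentinel, and the choice to report the total (rather than the average) penalty are assumptions for illustration; the penalty values of 0, 0.5, and 1 follow the scheme described above.

```python
ABSTAIN = "abstain"  # hypothetical sentinel for "model declined to predict"


def penalty(true_label, predicted, inlier_labels):
    """Penalty for one decision: 0 if correct, 0.5 for abstaining on an inlier, 1 otherwise."""
    true_is_inlier = true_label in inlier_labels
    if predicted == ABSTAIN:
        # Abstaining is correct for an outlier and a mild mistake for an inlier.
        return 0.5 if true_is_inlier else 0.0
    if true_is_inlier:
        return 0.0 if predicted == true_label else 1.0
    # An outlier was predicted as one of the inlier classes.
    return 1.0


def estimated_cost(true_labels, predictions, inlier_labels):
    """Total penalty over a test set at one fixed confidence threshold."""
    return sum(penalty(t, p, inlier_labels) for t, p in zip(true_labels, predictions))


inliers = {"acne", "eczema"}
truth = ["acne", "acne", "eczema", "rare_condition", "rare_condition"]
preds = ["acne", "eczema", ABSTAIN, ABSTAIN, "acne"]
cost = estimated_cost(truth, preds, inliers)
```

Repeating this computation across confidence thresholds would give a cost curve over operating points.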
Conclusion In real-world deployment, medical ML models may encounter conditions that were not seen in training, and it’s important that they accurately identify when they do not know a specific condition. Detecting those OOD inputs is an important step to improving safety. We develop an HOD loss that leverages outlier data during training, and combine it with pre-trained representation learning models and a diverse ensemble to further boost performance, significantly outperforming the baseline approach on our new dermatology benchmark dataset. We believe that our approach, aligned with our AI Principles, can aid successful translation of ML algorithms into real-world scenarios. Although we have primarily focused on OOD detection for dermatology, most of our contributions are fairly generic and can be easily incorporated into OOD detection for other applications.
Acknowledgements We would like to thank Shekoofeh Azizi, Aaron Loh, Vivek Natarajan, Basil Mustafa, Nick Pawlowski, Jan Freyberg, Yuan Liu, Zach Beaver, Nam Vo, Peggy Bui, Samantha Winter, Patricia MacWilliams, Greg S. Corrado, Umesh Telang, Yun Liu, Taylan Cemgil, Alan Karthikesalingam, Balaji Lakshminarayanan, and Jim Winkens for their contributions. We would also like to thank Tom Small for creating the post animation.
Quantum processors are made of superconducting quantum bits (qubits) that — being quantum objects — are highly susceptible to even tiny amounts of environmental noise. This noise can cause errors in quantum computation that need to be addressed to continue advancing quantum computers. Our Sycamore processors are installed in specially designed cryostats, where they are sealed away from stray light and electromagnetic fields and are cooled down to very low temperatures to reduce thermal noise.
However, the world is full of high-energy radiation. In fact, there’s a tiny background of high-energy gamma rays and muons that pass through everything around us all the time. While these particles interact so weakly that they don’t cause any harm in our day-to-day lives, qubits are sensitive enough that even weak particle interactions can cause significant interference.
In “Resolving Catastrophic Error Bursts from Cosmic Rays in Large Arrays of Superconducting Qubits”, published in Nature Physics, we identify the effects of these high-energy particles when they impact the quantum processor. To detect and study individual impact events, we use new techniques in rapid, repetitive measurement to operate our processor like a particle detector. This allows us to characterize the resulting burst of errors as they spread through the chip, helping to better understand this important source of correlated errors.
The Dynamics of a High-Energy Impact The Sycamore quantum processor is constructed with a very thin layer of superconducting aluminum on a silicon substrate, onto which a pattern is etched to define the qubits. At the center of each qubit is the Josephson junction, a superconducting component that defines the distinct energy levels of the qubit, which are used for computation. In a superconducting metal, electrons bind together into a macroscopic, quantum state, which allows electrons to flow as a current with zero resistance (a supercurrent). In superconducting qubits, information is encoded in different patterns of oscillating supercurrent going back and forth through the Josephson junction.
If enough energy is added to the system, the superconducting state can be broken up to produce quasiparticles. These quasiparticles are a problem, as they can absorb energy from the oscillating supercurrent and jump across the Josephson junction, which changes the qubit state and produces errors. To prevent any energy from being absorbed by the chip and producing quasiparticles, we use extensive shielding for electric and magnetic fields, and powerful cryogenic refrigerators to keep the chip near absolute zero temperature, thus minimizing the thermal energy.
A source of energy that we can’t effectively shield against is high-energy radiation, which includes charged particles and photons that can pass straight through most materials. One source of these particles is the tiny amount of radioactive elements that can be found everywhere, e.g., in building materials, the metal that makes up our cryostats, and even in the air. Another source is cosmic rays, which are extremely energetic particles produced by supernovae and black holes. When cosmic rays impact the upper atmosphere, they create a shower of high-energy particles that can travel all the way down to the surface and through our chip. Between radioactive impurities and cosmic ray showers, we expect a high-energy particle to pass through a quantum chip every few seconds.
When one of these particles impinges on the chip, it passes straight through and deposits a small amount of its energy along its path through the substrate. Even a small amount of energy from these particles is a very large amount of energy for the qubits. Regardless of where the impact occurs, the energy quickly spreads throughout the entire chip through quantum vibrations called phonons. When these phonons hit the aluminum layer that makes up the qubits, they have more than enough energy to break the superconducting state and produce quasiparticles. So many quasiparticles are produced that the probability of the qubits interacting with one becomes very high. We see this as a sudden and significant increase in errors over the whole chip as those quasiparticles absorb energy from the qubits. Eventually, as phonons escape and the chip cools, these quasiparticles recombine back into the superconducting state, and the qubit error rates slowly return to normal.
Detecting Particles with a Computer The Sycamore processor is designed to perform quantum error correction (QEC) to improve the error rates and enable it to execute a variety of quantum algorithms. QEC provides an effective way of identifying and mitigating errors, provided they are sufficiently rare and independent. However, in the case of a high-energy particle going through the chip, all of the qubits will experience high error rates until the event cools off, producing a correlated error burst that QEC won’t be able to correct. In order to successfully perform QEC, we first have to understand what these impact events look like on the processor, which requires operating it like a particle detector.
To do so, we take advantage of recent advances in qubit state preparation and measurement to quickly prepare each qubit in its excited state, similar to flipping a classical bit from 0 to 1. We then wait for a short idle time and measure whether the qubits are still excited. If the qubits are behaving normally, almost all of them will be. Further, the qubits that decay out of their excited state won’t be correlated, meaning the qubits that have errors will be randomly distributed over the chip.
However, during the experiment we occasionally observe large error bursts, where all the qubits on the chip suddenly become more error prone all at once. This correlated error burst is a clear signature of a high-energy impact event. We also see that, while all qubits on the chip are affected by the event, the qubits with the highest error rates are all concentrated in a “hotspot” around the impact site, where slightly more energy is deposited into the qubit layer by the spreading phonons.
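A simple way to picture how such bursts stand out in the data is to count, for each measurement round, how many qubits were found to have decayed, and to flag rounds where that count is far above the usual background. The sketch below is illustrative only: the per-round counts, the qubit count, and the fixed threshold are made-up values, and the actual analysis may use a different criterion.

```python
def detect_error_bursts(error_counts, num_qubits, threshold_fraction):
    """Return indices of measurement rounds where the errored fraction reaches the threshold."""
    if num_qubits <= 0:
        raise ValueError("num_qubits must be positive")
    return [
        index
        for index, count in enumerate(error_counts)
        if count / num_qubits >= threshold_fraction
    ]


# Hypothetical decay counts per round on a 26-qubit array: sparse, uncorrelated
# errors at first, then a sudden burst that decays back toward the background.
counts = [1, 0, 2, 20, 15, 6, 1]
burst_rounds = detect_error_bursts(counts, num_qubits=26, threshold_fraction=0.5)
```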
Next Steps Because these error bursts are severe and quickly cover the whole chip, they are a type of correlated error that QEC is unable to correct. Therefore, it’s very important to find a solution to mitigate these events in future processors that are expected to rely on QEC.
Shielding against these particles is very difficult and typically requires careful engineering and design of the cryostat and many meters of shielding, which becomes more impractical as processors grow in size. Another approach is to modify the chip, allowing it to tolerate impacts without causing widespread correlated errors. This is an approach taken in other complex superconducting devices like detectors for astronomical telescopes, where it’s not possible to use shielding. Examples of such mitigation strategies include adding additional metal layers to the chip to absorb phonons and prevent them from getting to the qubit, adding barriers in the chip to prevent phonons spreading over long distances, and adding traps for quasiparticles in the qubits themselves. By employing these techniques, future processors will be much more robust to these high-energy impact events.
As the error rates of quantum processors continue to decrease, and as we make progress in building a prototype of an error-corrected logical qubit, we’re increasingly pushed to study more exotic sources of error. While QEC is a powerful tool for correcting many kinds of errors, understanding and correcting more difficult sources of correlated errors will become increasingly important. We’re looking forward to future processor designs that can handle high-energy impacts and enable the first experimental demonstrations of working quantum error correction.
Acknowledgements This work wouldn’t have been possible without the contributions of the entire Google Quantum AI Team, especially those who worked to design, fabricate, install and calibrate the Sycamore processors used for this experiment. Special thanks to Rami Barends and Lev Ioffe, who led this project.
Image matting is the process of extracting a precise alpha matte that separates foreground and background objects in an image. This technique has been traditionally used in the filmmaking and photography industry for image and video editing purposes, e.g., background replacement, synthetic bokeh and other visual effects. Image matting assumes that an image is a composite of foreground and background images, and hence, the intensity of each pixel is a linear combination of the foreground and the background.
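The linear-combination assumption can be written per pixel as I = αF + (1 − α)B, where α is the matte value, F the foreground color, and B the background color. A minimal sketch with colors as RGB tuples in [0, 1] (the tuple representation and the example values are assumptions for illustration):

```python
def composite_pixel(alpha, foreground, background):
    """Blend one pixel: alpha * foreground + (1 - alpha) * background, per channel."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    return tuple(alpha * f + (1.0 - alpha) * b for f, b in zip(foreground, background))


# A wisp of red hair (25% coverage) over a blue background.
pixel = composite_pixel(0.25, (1.0, 0.0, 0.0), (0.0, 0.0, 1.0))
```

Matting is the inverse problem: given only the composite, recover α (and often F) for every pixel.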
In the case of traditional image segmentation, the image is segmented in a binary manner, in which a pixel either belongs to the foreground or background. This type of segmentation, however, is unable to deal with natural scenes that contain fine details, e.g., hair and fur, which require estimating a transparency value for each pixel of the foreground object.
Alpha mattes, unlike segmentation masks, are usually extremely precise, preserving strand-level hair details and accurate foreground boundaries. While recent deep learning techniques have shown their potential in image matting, many challenges remain, such as generating accurate ground truth alpha mattes, improving generalization on in-the-wild images, and running inference on high-resolution images on mobile devices.
With the Pixel 6, we have significantly improved the appearance of selfies taken in Portrait Mode by introducing a new approach to estimate a high-resolution and accurate alpha matte from a selfie image. When synthesizing the depth-of-field effect, the usage of the alpha matte allows us to extract a more accurate silhouette of the photographed subject and have a better foreground-background separation. This allows users with a wide variety of hairstyles to take great-looking Portrait Mode shots using the selfie camera. In this post, we describe the technology we used to achieve this improvement and discuss how we tackled the challenges mentioned above.
Portrait Matting In designing Portrait Matting, we trained a fully convolutional neural network consisting of a sequence of encoder-decoder blocks to progressively estimate a high-quality alpha matte. We concatenate the input RGB image with a coarse alpha matte (generated using a low-resolution person segmenter) and pass the result as input to the network. The new Portrait Matting model uses a MobileNetV3 backbone and a shallow (i.e., having a low number of layers) decoder, which operate on a low-resolution image, to first predict a refined low-resolution alpha matte. Then we use a shallow encoder-decoder and a series of residual blocks to process a high-resolution image and the refined alpha matte from the previous step. The shallow encoder-decoder relies more on lower-level features than the previous MobileNetV3 backbone, focusing on high-resolution structural features to predict final transparency values for each pixel. In this way, the model is able to refine an initial foreground alpha matte and accurately extract very fine details like hair strands. The proposed neural network architecture runs efficiently on Pixel 6 using TensorFlow Lite.
Most recent deep learning work for image matting relies on manually annotated per-pixel alpha mattes, generated with image editing tools or green screens, to separate the foreground from the background. This process is tedious and does not scale to the generation of large datasets. Also, it often produces inaccurate alpha mattes and foreground images that are contaminated (e.g., by reflected light from the background, or “green spill”). Moreover, this does nothing to ensure that the lighting on the subject appears consistent with the lighting in the new background environment.
To address these challenges, Portrait Matting is trained using a high-quality dataset generated using a custom volumetric capture system, Light Stage. Compared with previous datasets, this is more realistic, as relighting allows the illumination of the foreground subject to match the background. Additionally, we supervise the training of the model using pseudo–ground truth alpha mattes from in-the-wild images to improve model generalization, explained below. This ground truth data generation process is one of the key components of this work.
Ground Truth Data Generation To generate accurate ground truth data, Light Stage produces near-photorealistic models of people using a geodesic sphere outfitted with 331 custom color LED lights, an array of high-resolution cameras, and a set of custom high-resolution depth sensors. Together with Light Stage data, we compute accurate alpha mattes using time-multiplexed lights and a previously recorded “clean plate”. This technique is also known as ratio matting.
Then, we extrapolate the recorded alpha mattes to all the camera viewpoints in Light Stage using a deep learning–based matting network that leverages captured clean plates as an input. This approach allows us to extend the alpha mattes computation to unconstrained backgrounds without the need for specialized time-multiplexed lighting or a clean background. This deep learning architecture was solely trained using ground truth mattes generated using the ratio matting approach.
Leveraging the reflectance field for each subject and the alpha matte generated with our ground truth matte generation system, we can relight each portrait using a given HDR lighting environment. We composite these relit subjects into backgrounds corresponding to the target illumination following the alpha blending equation. The background images are then generated from the HDR panoramas by positioning a virtual camera at the center and ray-tracing into the panorama from the camera’s center of projection. We ensure that the projected view into the panorama matches its orientation as used for relighting. We use virtual cameras with different focal lengths to simulate the different fields-of-view of consumer cameras. This pipeline produces realistic composites by handling matting, relighting, and compositing in one system, which we then use to train the Portrait Matting model.
Training Supervision Using In-the-Wild Portraits To bridge the gap between portraits generated using Light Stage and in-the-wild portraits, we created a pipeline that automatically annotates in-the-wild photos, generating pseudo–ground truth alpha mattes. For this purpose, we leveraged the Deep Matting model proposed in Total Relighting to create an ensemble of models that computes multiple high-resolution alpha mattes from in-the-wild images. We ran this pipeline on an extensive dataset of portrait photos captured in-house using Pixel phones. Additionally, during this process we performed test-time augmentation by running inference on input images at different scales and rotations, and then aggregating per-pixel alpha values across all estimated alpha mattes.
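The final aggregation step can be sketched as follows. The post does not say which aggregation function was used, so the per-pixel mean here is an assumption, as is the requirement that every matte has already been mapped back to the original image's scale and orientation.

```python
from statistics import fmean


def aggregate_mattes(mattes):
    """Average per-pixel alpha across equally sized mattes (lists of rows).

    Assumes each matte was already undone to the input frame, e.g. rotated
    and rescaled back after test-time augmentation.
    """
    height, width = len(mattes[0]), len(mattes[0][0])
    return [
        [fmean(matte[y][x] for matte in mattes) for x in range(width)]
        for y in range(height)
    ]


# Three hypothetical estimates of the same 1x2 matte.
estimates = [[[0.0, 1.0]], [[0.5, 1.0]], [[1.0, 1.0]]]
print(aggregate_mattes(estimates))
```

A median or a confidence-weighted average would fit the same interface if outlier mattes turn out to be a concern.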
Generated alpha mattes are visually evaluated with respect to the input RGB image. The alpha mattes that are perceptually correct, i.e., following the subject's silhouette and fine details (e.g., hair), are added to the training set. During training, both datasets are sampled using different weights. Using the proposed supervision strategy exposes the model to a larger variety of scenes and human poses, improving its predictions on photos in the wild (model generalization).
Portrait Mode Selfies The Portrait Mode effect is particularly sensitive to errors around the subject boundary (see image below). For example, errors caused by the usage of a coarse alpha matte keep sharp focus on background regions near the subject boundaries or hair area. The usage of a high-quality alpha matte allows us to extract a more accurate silhouette of the photographed subject and improve foreground-background separation.
Try It Out Yourself We have made front-facing camera Portrait Mode on the Pixel 6 better by improving alpha matte quality, which results in fewer errors in the final rendered image and improves the look of the blurred background around the hair region and subject boundary. Additionally, our ML model uses diverse training datasets that cover a wide variety of skin tones and hair styles. You can try this improved version of Portrait Mode by taking a selfie shot with the new Pixel 6 phones.
Acknowledgments This work wouldn’t have been possible without Sergio Orts Escolano, Jana Ehmann, Sean Fanello, Christoph Rhemann, Junlan Yang, Andy Hsu, Hossam Isack, Rohit Pandey, David Aguilar, Yi Jinn, Christian Hane, Jay Busch, Cynthia Herrera, Matt Whalen, Philip Davidson, Jonathan Taylor, Peter Lincoln, Geoff Harvey, Nisha Masharani, Alexander Schiffhauer, Chloe LeGendre, Paul Debevec, Sofien Bouaziz, Adarsh Kowdle, Thabo Beeler, Chia-Kai Liang and Shahram Izadi. Special thanks to our photographers James Adamson, Christopher Farro and Cort Muller who took numerous test photographs for us.
Birds are all around us, and just by listening, we can learn many things about our environment. Ecologists use birds to understand food systems and forest health — for example, if there are more woodpeckers in a forest, that means there’s a lot of dead wood. Because birds communicate and mark territory with songs and calls, it’s most efficient to identify them by ear. In fact, experts may identify up to 10x as many birds by ear as by sight.
In recent years, autonomous recording units (ARUs) have made it easy to capture thousands of hours of audio in forests that could be used to better understand ecosystems and identify critical habitat. However, manually reviewing the audio data is very time-consuming, and experts in birdsong are rare. An approach based on machine learning (ML) has the potential to greatly reduce the amount of expert review needed for understanding a habitat.
However, ML-based audio classification of bird species can be challenging for several reasons. For one, birds often sing over one another, especially during the “dawn chorus” when many birds are most active. Also, there aren’t clear recordings of individual birds to learn from — almost all of the available training data is recorded in noisy outdoor conditions, where other sounds from the wind, insects, and other environmental sources are often present. As a result, existing birdsong classification models struggle to identify quiet, distant and overlapping vocalizations. Additionally, some of the most common species often appear unlabeled in the background of training recordings for less common species, leading models to discount the common species. These difficult cases are very important for ecologists who want to identify endangered or invasive species using automated systems.
To address the general challenge of training ML models to automatically separate audio recordings without access to examples of isolated sounds, we recently proposed a new unsupervised method called mixture invariant training (MixIT) in our paper, “Unsupervised Sound Separation Using Mixture Invariant Training”. Moreover, in our new paper, “Improving Bird Classification with Unsupervised Sound Separation,” we use MixIT training to separate birdsong and improve species classification. We found that including the separated audio in the classification improves precision and classification quality on three independent soundscape datasets. We are also happy to announce the open-source release of the birdsong separation models on GitHub.
Birdsong Audio Separation MixIT learns to separate single-channel recordings into multiple individual tracks, and can be trained entirely with noisy, real-world recordings. To train the separation model, we create a “mixture of mixtures” (MoM) by mixing together two real-world recordings. The separation model then learns to take the MoM apart into many channels to minimize a loss function that uses the two original real-world recordings as ground-truth references. The loss function uses these references to group the separated channels such that they can be mixed back together to recreate the two original real-world recordings. Since there’s no way to know how the different sounds in the MoM were grouped together in the original recordings, the separation model has no choice but to separate the individual sounds themselves, and thus learns to place each singing bird in a different output audio channel, also separate from wind and other background noise.
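The grouping step in the loss can be illustrated with a small brute-force sketch. It tries every way of assigning the separated channels to the two reference recordings and keeps the assignment with the lowest reconstruction error. Mean squared error is used here for simplicity, which is an assumption and not necessarily the loss used in the paper; the function names are illustrative.

```python
from itertools import product


def mse(a, b):
    """Mean squared error between two equal-length signals."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def mixit_loss(mix1, mix2, sources):
    """Return (loss, assignment) for the best grouping of separated sources.

    assignment[i] is 0 if source i is remixed into mix1, 1 if into mix2.
    """
    n = len(mix1)
    best = None
    for assignment in product((0, 1), repeat=len(sources)):
        remix = [[0.0] * n, [0.0] * n]
        for source, target in zip(sources, assignment):
            for t, value in enumerate(source):
                remix[target][t] += value
        loss = mse(remix[0], mix1) + mse(remix[1], mix2)
        if best is None or loss < best[0]:
            best = (loss, assignment)
    return best


# Toy example: recording 1 = bird A + wind, recording 2 = bird B.
bird_a, bird_b, wind = [1.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 2.0]
mix1 = [a + w for a, w in zip(bird_a, wind)]
print(mixit_loss(mix1, bird_b, [bird_a, bird_b, wind]))
```

The exhaustive search grows as 2 to the power of the number of output channels, which is manageable for the small channel counts typical of a separation model.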
We trained a new MixIT separation model using birdsong recordings from Xeno-Canto and the Macaulay Library. We found that for separating birdsong, this new model outperformed a MixIT separation model trained on a large amount of general audio from the AudioSet dataset. We measure the quality of the separation by mixing two recordings together, applying separation, and then remixing the separated audio channels such that they reconstruct the original two recordings. We measure the signal-to-noise ratio (SNR) of the remixed audio relative to the original recordings. We found that the model trained specifically for birds achieved 6.1 decibels (dB) better SNR than the model trained on AudioSet (10.5 dB vs 4.4 dB). Subjectively, we also found many examples where the system worked incredibly well, separating very difficult to distinguish calls in real-world data.
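For reference, SNR in decibels is 10·log10 of the ratio between the energy of the reference signal and the energy of the error. A minimal sketch of that measurement, with an illustrative function name and toy signals:

```python
import math


def snr_db(reference, estimate):
    """Signal-to-noise ratio in dB of an estimate relative to a reference."""
    signal_energy = sum(r * r for r in reference)
    noise_energy = sum((r - e) ** 2 for r, e in zip(reference, estimate))
    if noise_energy == 0:
        return math.inf  # perfect reconstruction
    return 10.0 * math.log10(signal_energy / noise_energy)


print(snr_db([10.0, 0.0], [9.0, 0.0]))
```

Because the scale is logarithmic, the 6.1 dB gap reported above corresponds to roughly a fourfold reduction in error energy.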
The following videos demonstrate separation of birdsong from two different regions (Caples and the High Sierras). The videos show the mel-spectrogram of the mixed audio (a 2D image that shows the frequency content of the audio over time) and highlight the audio separated into different tracks.
Classifying Bird Species To classify birds in real-world audio captured with ARUs, we first split the audio into five-second segments and then create a mel-spectrogram of each segment. We then train an EfficientNet classifier to identify bird species from the mel-spectrogram images, training on audio from Xeno-Canto and the Macaulay Library. We trained two separate classifiers, one for species in the Sierra Nevada mountains and one for upstate New York. Note that these classifiers are not trained on separated audio; that’s an area for future improvement.
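The segmentation step can be sketched as below. Dropping a trailing partial segment is an assumption made for brevity; padding it would be an equally reasonable choice. The mel-spectrogram computation is omitted.

```python
def split_into_segments(samples, sample_rate, segment_seconds=5.0):
    """Split a mono signal into consecutive fixed-length segments.

    Any leftover samples shorter than one segment are dropped.
    """
    segment_len = int(sample_rate * segment_seconds)
    last_start = len(samples) - segment_len
    return [samples[start:start + segment_len]
            for start in range(0, last_start + 1, segment_len)]


# 12.5 seconds of audio at a toy sample rate of 2 Hz gives two full segments.
segments = split_into_segments(list(range(25)), sample_rate=2)
print(len(segments), [len(s) for s in segments])
```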
We also introduced some new techniques to improve classifier training. Taxonomic training asks the classifier to provide labels for each level of the species taxonomy (genus, family, and order), which allows the model to learn groupings of species before learning the sometimes-subtle differences between similar species. Taxonomic training also allows the model to benefit from expert information about the taxonomic relationships between different species. We also found that random low-pass filtering was helpful for simulating distant sounds during training: as an audio source gets farther away, the high-frequency parts fade away before the low-frequency parts. This was particularly effective for identifying species from the High Sierras region, where birdsongs cover very long distances, unimpeded by trees.
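Random low-pass filtering as an augmentation might look like the following sketch. The one-pole filter and the cutoff range are assumptions chosen for illustration; the post does not specify the filter design or its parameters.

```python
import math
import random


def one_pole_low_pass(samples, sample_rate, cutoff_hz):
    """Simple one-pole low-pass filter: attenuates content above cutoff_hz."""
    alpha = 1.0 - math.exp(-2.0 * math.pi * cutoff_hz / sample_rate)
    filtered, state = [], 0.0
    for x in samples:
        state += alpha * (x - state)  # move a fraction of the way toward the input
        filtered.append(state)
    return filtered


def random_low_pass(samples, sample_rate, rng, min_hz=1000.0, max_hz=8000.0):
    """Apply a low-pass filter with a randomly drawn cutoff (hypothetical range)."""
    return one_pole_low_pass(samples, sample_rate, rng.uniform(min_hz, max_hz))


# A rapidly alternating signal (high frequency) is strongly attenuated.
chirpy = [1.0 if i % 2 == 0 else -1.0 for i in range(200)]
print(max(abs(v) for v in one_pole_low_pass(chirpy, 32000, 1000.0)))
augmented = random_low_pass(chirpy, 32000, random.Random(0))
```

Lower cutoffs imitate more distant sources, so drawing the cutoff at random exposes the classifier to a range of apparent distances.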
Classifying Separated Audio We found that separating audio with the new MixIT model before classification improved the classifier performance on three independent real-world datasets. The separation was particularly successful for identification of quiet and background birds, and in many cases helped with overlapping vocalizations as well.
The separation model does have some potential limitations. Occasionally we observe over-separation, where a single song is broken into multiple channels, which can cause misclassifications. We also notice that when multiple birds are vocalizing, the most prominent song often gets a lower score after separation. This may be due to loss of environmental context or other artifacts introduced by separation that do not appear during classifier training. For now, we get the best results by running the classifier on the separated channels and the original audio, and taking the maximum score for each species. We expect that further work will allow us to reduce over-separation and find better ways to combine separation and classification. You can see and hear more examples of the full system at our GitHub repo.
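The score-combination rule described above, taking the maximum per-species score over the original audio and every separated channel, can be sketched as follows. The species names and scores are made up for illustration.

```python
def combine_scores(score_dicts):
    """Take the maximum score per species across several classifier outputs."""
    combined = {}
    for scores in score_dicts:
        for species, score in scores.items():
            combined[species] = max(score, combined.get(species, float("-inf")))
    return combined


original = {"robin": 0.9, "wren": 0.2}           # classifier on the unseparated audio
channels = [{"robin": 0.6, "wren": 0.1},         # classifier on separated channel 1
            {"robin": 0.05, "wren": 0.7}]        # classifier on separated channel 2
print(combine_scores([original] + channels))
```

In this toy case the prominent species keeps its high score from the original audio, while the quieter one benefits from the separated channel.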
Future Directions We are currently working with partners at the California Academy of Sciences to understand how habitat and species mix changes after prescribed fires and wildfires, applying these models to ARU audio collected over many years.
We also foresee many potential applications for the unsupervised separation models in ecology, beyond just birds. For example, the separated audio can be used to create better acoustic indices, which could measure ecosystem health by tracking the total activity of birds, insects, and amphibians without identifying particular species. Similar methods could also be adapted for use underwater to track coral reef health.
Acknowledgements We would like to thank Mary Clapp, Jack Dumbacher, and Durrell Kapan from the California Academy of Sciences for providing extensive annotated soundscapes from the Sierra Nevadas. Stefan Kahl and Holger Klinck from the Cornell Lab of Ornithology provided soundscapes from Sapsucker Woods. Training data for both the separation and classification models came from Xeno-Canto and the Macaulay Library. Finally, we would like to thank Julie Cattiau, Lauren Harrell, Matt Harvey, and our co-author, John Hershey, from the Google Bioacoustics and Sound Separation teams.
Language models are becoming more capable than ever before and are helpful in a variety of tasks — translating one language into another, summarizing a long document into a brief highlight, or answering information-seeking questions. Among these, open-domain dialog, where a model needs to be able to converse about any topic, is probably one of the most difficult, with a wide range of potential applications and open challenges. In addition to producing responses that humans judge as sensible, interesting, and specific to the context, dialog models should adhere to Responsible AI practices, and avoid making factual statements that are not supported by external information sources.
Today we’re excited to share recent advances in our “LaMDA: Language Models for Dialog Applications” project. In this post, we’ll give an overview on how we’re making progress towards safe, grounded, and high-quality dialog applications. LaMDA is built by fine-tuning a family of Transformer-based neural language models specialized for dialog, with up to 137B model parameters, and teaching the models to leverage external knowledge sources.
Objectives & Metrics Defining objectives and metrics is critical to guide training dialog models. LaMDA has three key objectives — Quality, Safety, and Groundedness — each of which we measure using carefully designed metrics:
Quality: We decompose Quality into three dimensions, Sensibleness, Specificity, and Interestingness (SSI), which are evaluated by human raters. Sensibleness refers to whether the model produces responses that make sense in the dialog context (e.g., no common sense mistakes, no absurd responses, and no contradictions with earlier responses). Specificity is measured by judging whether the system's response is specific to the preceding dialog context, and not a generic response that could apply to most contexts (e.g., “ok” or “I don’t know”). Finally, Interestingness measures whether the model produces responses that are also insightful, unexpected or witty, and are therefore more likely to create better dialog.
Safety: We’re also making progress towards addressing important questions related to the development and deployment of Responsible AI. Our Safety metric is composed of an illustrative set of safety objectives that captures the behavior that the model should exhibit in a dialog. These objectives attempt to constrain the model’s output to avoid any unintended results that create risks of harm for the user, and to avoid reinforcing unfair bias. For example, these objectives train the model to avoid producing outputs that contain violent or gory content, promote slurs or hateful stereotypes towards groups of people, or contain profanity. Our research towards developing a practical Safety metric represents very early work, and there is still a great deal of progress for us to make in this area.
Groundedness: The current generation of language models often generates statements that seem plausible, but actually contradict facts established in known external sources. This motivates our study of groundedness in LaMDA. Groundedness is defined as the percentage of responses with claims about the external world that can be supported by authoritative external sources, as a share of all responses containing claims about the external world. A related metric, Informativeness, is defined as the percentage of responses with information about the external world that can be supported by known sources, as a share of all responses. Therefore, casual responses that do not carry any real-world information (e.g., “That’s a great idea”) affect Informativeness but not Groundedness. While grounding LaMDA-generated responses in known sources does not in itself guarantee factual accuracy, it allows users or external systems to judge the validity of a response based on the reliability of its source.
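The difference between Groundedness and Informativeness comes down to the denominator, as this small sketch shows. The pair-of-booleans representation of a rated response is an assumption made for illustration.

```python
def groundedness_and_informativeness(responses):
    """Compute both metrics from rated responses.

    Each response is a pair (makes_external_claim, supported_by_source).
    Groundedness divides by responses that make claims; Informativeness
    divides by all responses.
    """
    total = len(responses)
    with_claims = sum(1 for claim, _ in responses if claim)
    supported = sum(1 for claim, ok in responses if claim and ok)
    groundedness = supported / with_claims if with_claims else 0.0
    informativeness = supported / total if total else 0.0
    return groundedness, informativeness


ratings = [
    (True, True),    # factual claim backed by a source
    (True, False),   # factual claim with no support
    (False, False),  # casual reply such as "That's a great idea"
    (True, True),
]
print(groundedness_and_informativeness(ratings))
```

Adding more casual replies to `ratings` lowers Informativeness but leaves Groundedness unchanged.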
LaMDA Pre-Training With the objectives and metrics defined, we describe LaMDA’s two-stage training: pre-training and fine-tuning. In the pre-training stage, we first created a dataset of 1.56T words — nearly 40 times more words than what were used to train previous dialog models — from public dialog data and other public web documents. After tokenizing the dataset into 2.81T SentencePiece tokens, we pre-train the model using GSPMD to predict every next token in a sentence, given the previous tokens. The pre-trained LaMDA model has also been widely used for natural language processing research across Google, including program synthesis, zero-shot learning, style transfer, as well as in the BIG-bench workshop.
LaMDA Fine-Tuning In the fine-tuning stage, we train LaMDA to perform a mix of generative tasks to generate natural-language responses to given contexts, and classification tasks on whether a response is safe and high-quality, resulting in a single multi-task model that can do both. The LaMDA generator is trained to predict the next token on a dialog dataset restricted to back-and-forth dialog between two authors, while the LaMDA classifiers are trained to predict the Safety and Quality (SSI) ratings for the response in context using annotated data. During a dialog, the LaMDA generator first generates several candidate responses given the current multi-turn dialog context, and the LaMDA classifiers predict the SSI and Safety scores for every response candidate. Candidate responses with low Safety scores are first filtered out. Remaining candidates are re-ranked by their SSI scores, and the top result is selected as the response. We further filter the training data used for the generation task with LaMDA classifiers to increase the density of high-quality response candidates.
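The filter-then-rerank step at inference time can be sketched as below. The safety threshold value, the dictionary layout, and the use of a single combined SSI score are assumptions for illustration; the post does not give these details.

```python
def select_response(candidates, safety_threshold=0.5):
    """Drop candidates with low Safety scores, then return the top SSI text.

    Each candidate is a dict with "text", "safety", and "ssi" keys.
    Returns None when no candidate passes the safety filter.
    """
    safe = [c for c in candidates if c["safety"] >= safety_threshold]
    if not safe:
        return None
    return max(safe, key=lambda c: c["ssi"])["text"]


candidates = [
    {"text": "ok", "safety": 0.99, "ssi": 0.20},
    {"text": "A specific, interesting reply", "safety": 0.95, "ssi": 0.85},
    {"text": "An unsafe but catchy reply", "safety": 0.10, "ssi": 0.95},
]
print(select_response(candidates))
```

Filtering before ranking means a candidate with a high quality score is never selected if it fails the safety check.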
Factual Grounding While people are capable of checking their facts by using tools and referencing established knowledge bases, many language models draw their knowledge only from their internal model parameters. To improve the groundedness of LaMDA’s original response, we collect a dataset of dialogs between people and LaMDA, which are annotated with information retrieval queries and the retrieved results where applicable. We then fine-tune LaMDA’s generator and classifier on this dataset to learn to call an external information retrieval system during its interaction with the user to improve the groundedness of its responses. While this is very early work, we’re seeing promising results.
Evaluation In order to quantify progress against our key metrics, we collect responses from the pre-trained model, fine-tuned model, and human raters (i.e., human-generated responses) to multi-turn two-author dialogs, and then ask a different set of human raters a series of questions to evaluate these responses against the Quality, Safety, and Groundedness metrics.
We observe that LaMDA significantly outperforms the pre-trained model in every dimension and across all model sizes. Quality metrics (Sensibleness, Specificity, and Interestingness, in the first column below) generally improve with the number of model parameters, with or without fine-tuning. Safety does not seem to benefit from model scaling alone, but it does improve with fine-tuning. Groundedness improves as model size increases, perhaps because larger models have a greater capacity to memorize uncommon knowledge, but fine-tuning allows the model to access external knowledge sources and effectively shift some of the load of remembering knowledge to an external knowledge source. With fine-tuning, the quality gap to human levels can be narrowed, though the model’s performance remains below human levels in safety and groundedness.
Future Research & Challenges LaMDA’s level of Sensibleness, Specificity and Interestingness unlocks new avenues for understanding the benefits and risks of open-ended dialog agents. It also presents encouraging evidence that key challenges with neural language models, such as meeting a safety metric and improving groundedness, can be addressed with larger models and fine-tuning with more well-labeled data. However, this is very early work, and there are significant limitations. Exploring new ways to improve our Safety metric and LaMDA’s groundedness, aligned with our AI Principles, will continue to be our main areas of focus going forward.
Acknowledgements We'd like to thank everyone for contributing to the project and paper, including: Blaise Aguera-Arcas, Javier Alberca, Thushan Amarasiriwardena, Lora Aroyo, Martin Baeuml, Leslie Baker, Rachel Bernstein, Taylor Bos, Maarten Bosma, Jonas Bragagnolo, Alena Butryna, Bill Byrne, Chung-Ching Chang, Zhifeng Chen, Dehao Chen, Heng-Tze Cheng, Ed Chi, Aaron Cohen, Eli Collins, Marian Croak, Claire Cui, Andrew Dai, Dipanjan Das, Daniel De Freitas, Jeff Dean, Rajat Dewan, Mark Diaz, Tulsee Doshi, Yu Du, Toju Duke, Doug Eck, Joe Fenton, Noah Fiedel, Christian Frueh, Harish Ganapathy, Saravanan Ganesh, Amin Ghafouri, Zoubin Ghahramani, Kourosh Gharachorloo, Jamie Hall, Erin Hoffman-John, Sissie Hsiao, Yanping Huang, Ben Hutchinson, Daphne Ippolito, Alicia Jin, Thomas Jurdi, Ashwin Kakarla, Nand Kishore, Maxim Krikun, Karthik Krishnamoorthi, Igor Krivokon, Apoorv Kulshreshtha, Ray Kurzweil, Viktoriya Kuzmina, Vivek Kwatra, Matthew Lamm, Quoc Le, Max Lee, Katherine Lee, Hongrae Lee, Josh Lee, Dmitry Lepikhin, YaGuang Li, Yifeng Lu, David Luan, Daphne Luong, Laichee Man, Jianchang (JC) Mao, Yossi Matias, Kathleen Meier-Hellstern, Marcelo Menegali, Muqthar Mohammad, Alejandra Molina, Erica Moreira, Meredith Ringel Morris, Maysam Moussalem, Jiaqi Mu, Tyler Mullen, Eric Ni, Kristen Olson, Alexander Passos, Fernando Pereira, Slav Petrov, Marc Pickett, Roberto Pieraccini, Christian Plagemann, Sahitya Potluri, Vinodkumar Prabhakaran, Andy Pratt, James Qin, Ravi Rajakumar, Adam Roberts, Will Rusch, Renelito Delos Santos, Noam Shazeer, RJ Skerry-Ryan, Grigori Somin, Johnny Soraker, Pranesh Srinivasan, Amarnag Subramanya, Mustafa Suleyman, Romal Thoppilan, Song Wang, Sheng Wang, Chris Wassman, Yuanzhong Xu, Ni Yan, Ben Zevenbergen, Vincent Zhao, Huaixiu Steven Zheng, Denny Zhou, Hao Zhou, Yanqi Zhou, and more.