#include "ash/wm/layer_tree_synchronizer.h"
#include <array>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <utility>
#include <vector>
#include "base/check_op.h"
#include "base/numerics/angle_conversions.h"
#include "ui/aura/window.h"
#include "ui/compositor/layer.h"
#include "ui/gfx/geometry/mask_filter_info.h"
#include "ui/gfx/geometry/point_f.h"
#include "ui/gfx/geometry/rect_f.h"
#include "ui/gfx/geometry/rrect_f.h"
#include "ui/gfx/geometry/transform.h"
#include "ui/gfx/geometry/vector2d_f.h"
namespace ash {
namespace {
// Shorthand for the enum that names one corner of a gfx::RRectF.
using Corner = gfx::RRectF::Corner;
// Capacity reserved up front, by the LayerTreeSynchronizerBase constructor,
// for the container that records the original state of modified layers.
constexpr int kExpectedModifiedLayers = 8;
// Returns `n` squared.
//
// A plain multiplication is used instead of std::pow(n, 2): the latter goes
// through the general double-precision power routine and then narrows the
// result back to float, while the product yields the same value directly.
constexpr float Square(float n) {
  return n * n;
}
// Describes an arc of the circle centered at `center` with radius `radius`,
// spanning the angles from `start_angle` to `end_angle` (both in degrees).
//
// Angles are measured with the y offset negated (see InclusivelyContains() and
// GetMidPointOnArc()), i.e. a point whose y coordinate is smaller than
// `center.y()` has an angle in (0, 180). GetArcForCorner() only ever builds
// quarter-circle arcs whose bounds are multiples of 90 degrees.
struct CircularArc {
  static constexpr int kDegree0 = 0;
  static constexpr int kDegree90 = 90;
  static constexpr int kDegree180 = 180;
  static constexpr int kDegree270 = 270;
  static constexpr int kDegree360 = 360;

  gfx::PointF center;
  float radius = 0.0f;
  int start_angle = kDegree0;
  int end_angle = kDegree0;

  // Returns true if `this` and `other` share at least one point. When both
  // arcs lie on the same circle, returns true only if the angular range of
  // `this` is contained in the angular range of `other`.
  bool Intersects(const CircularArc& other) const {
    return FindIntersection(other);
  }

  // Returns true if `point` is no farther than `radius` from `center` and its
  // angle around `center`, truncated to whole degrees, lies within
  // [`start_angle`, `end_angle`].
  bool InclusivelyContains(const gfx::PointF point) const {
    const gfx::Vector2dF distance_vec = point - center;
    if (distance_vec.Length() > radius) {
      return false;
    }

    // atan2() yields an angle in [-pi, +pi]. The y offset is negated so that
    // the angle matches the convention used by `start_angle`/`end_angle`.
    const float angle_in_radians =
        std::atan2(-distance_vec.y(), distance_vec.x());

    // Shift negative angles by a full turn so the result is in [0, 360].
    const int normalized_angle_in_degree =
        base::RadToDeg(angle_in_radians) +
        (angle_in_radians < kDegree0 ? kDegree360 : kDegree0);

    return normalized_angle_in_degree >= start_angle &&
           normalized_angle_in_degree <= end_angle;
  }

  // Returns the point on the circle at the angle halfway between
  // `start_angle` and `end_angle` (integer division).
  gfx::PointF GetMidPointOnArc() const {
    const float mid_angle_in_degree = (start_angle + end_angle) / 2;
    const float mid_angle_in_radian = base::DegToRad(mid_angle_in_degree);
    return {center.x() + radius * std::cos(mid_angle_in_radian),
            center.y() - radius * std::sin(mid_angle_in_radian)};
  }

 private:
  // The intersection points of two circles. If the bool is true the circles
  // coincide and therefore have infinitely many points in common; the vector
  // is empty in that case.
  using Points = std::pair<std::vector<gfx::PointF>, bool>;

  // Implementation of Intersects().
  bool FindIntersection(const CircularArc& other) const {
    const Points points = FindIntersectionBetweenTwoCircles(
        center, radius, other.center, other.radius);

    // The circles coincide, so only the angular ranges need comparing.
    if (points.second) {
      return start_angle >= other.start_angle && end_angle <= other.end_angle;
    }

    // An intersection point of the circles only counts if it belongs to both
    // arcs.
    for (const auto& point : points.first) {
      if (InclusivelyContains(point) && other.InclusivelyContains(point)) {
        return true;
      }
    }
    return false;
  }

  // Returns the points shared by the circle (`c1`, `r1`) and the circle
  // (`c2`, `r2`): none, one (tangent circles), two, or "infinitely many" when
  // the circles coincide.
  //
  // With d = |c2 - c1|, the chord through the intersection points crosses the
  // line joining the centers at p5 = c1 + a * (c2 - c1) / d, where
  // a = (r1^2 - r2^2 + d^2) / (2 * d), and the intersection points lie at
  // p5 +/- h * perp / d, where h^2 = r1^2 - a^2 and perp is (c2 - c1) rotated
  // by 90 degrees.
  Points FindIntersectionBetweenTwoCircles(const gfx::PointF& c1,
                                           float r1,
                                           const gfx::PointF& c2,
                                           float r2) const {
    static constexpr Points kNoIntersectingPoint = {{}, false};
    static constexpr Points kInfiniteIntersectingPoints = {{}, true};

    if (c1 == c2 && r1 == r2) {
      return kInfiniteIntersectingPoints;
    }

    const gfx::Vector2dF distance_vec = c2 - c1;
    const float d = distance_vec.Length();

    // The circles are too far apart to touch.
    if (d > r1 + r2) {
      return kNoIntersectingPoint;
    }

    // One circle lies strictly inside the other. This also covers concentric
    // circles with different radii, so `d` is non-zero below.
    if (d < std::abs(r1 - r2)) {
      return kNoIntersectingPoint;
    }

    // `a` and `h` are both pre-divided by `d` so that they can scale
    // `distance_vec` (whose length is `d`) directly.
    const float a = (Square(r1) - Square(r2) + Square(d)) / (2 * Square(d));

    // Since `a` is already scaled by 1/d, the real distance along the center
    // line is `a * d`. Rounding can push the difference slightly below zero
    // for tangent circles; treat that as h == 0 rather than producing NaN.
    const float h_squared = Square(r1) - Square(a * d);
    const float h = h_squared > 0 ? std::sqrt(h_squared) / d : 0.0f;

    const float p5_x = c1.x() + a * distance_vec.x();
    const float p5_y = c1.y() + a * distance_vec.y();

    // Tangent circles touch at a single point.
    if (h == 0) {
      return {{gfx::PointF(p5_x, p5_y)}, false};
    }

    // The two intersection points are mirror images of each other about p5.
    const float p3_x = p5_x - h * distance_vec.y();
    const float p3_y = p5_y + h * distance_vec.x();
    const float p4_x = p5_x + h * distance_vec.y();
    const float p4_y = p5_y - h * distance_vec.x();
    return {{gfx::PointF(p3_x, p3_y), gfx::PointF(p4_x, p4_y)}, false};
  }
};
// Builds the quarter-circle arc describing the rounded `corner` of `rrectf`.
// The arc radius is the x component of the corner radii. The arc center is the
// vertex of the corner's bounding rect that is diagonally opposite to the
// named corner (e.g. the bottom-left vertex for the upper-right corner).
CircularArc GetArcForCorner(const gfx::RRectF& rrectf, Corner corner) {
  CircularArc arc;
  arc.radius = rrectf.GetCornerRadii(corner).x();

  const gfx::RectF corner_box = rrectf.CornerBoundingRect(corner);
  const auto place = [&arc](const gfx::PointF& arc_center, int from, int to) {
    arc.center = arc_center;
    arc.start_angle = from;
    arc.end_angle = to;
  };

  switch (corner) {
    case Corner::kUpperLeft:
      place(corner_box.bottom_right(), CircularArc::kDegree90,
            CircularArc::kDegree180);
      break;
    case Corner::kUpperRight:
      place(corner_box.bottom_left(), CircularArc::kDegree0,
            CircularArc::kDegree90);
      break;
    case Corner::kLowerRight:
      place(corner_box.origin(), CircularArc::kDegree270,
            CircularArc::kDegree360);
      break;
    case Corner::kLowerLeft:
      place(corner_box.top_right(), CircularArc::kDegree180,
            CircularArc::kDegree270);
      break;
  }
  return arc;
}
// Returns true if `p` lies inside `rrectf`, taking its rounded corners into
// account:
//   * false if `p` is outside the bounding rect of `rrectf`;
//   * true if `p` is inside the bounding rect but in none of the four corner
//     regions (the radii-sized boxes at each corner);
//   * otherwise, whether `p` is inside the ellipse of that corner.
bool CheckCornerContainment(const gfx::PointF& p, const gfx::RRectF& rrectf) {
  const gfx::RectF bounds = rrectf.rect();
  if (!bounds.InclusiveContains(p)) {
    return false;
  }

  const gfx::Vector2dF ll_radii = rrectf.GetCornerRadii(Corner::kLowerLeft);
  const gfx::Vector2dF lr_radii = rrectf.GetCornerRadii(Corner::kLowerRight);
  const gfx::Vector2dF ul_radii = rrectf.GetCornerRadii(Corner::kUpperLeft);
  const gfx::Vector2dF ur_radii = rrectf.GetCornerRadii(Corner::kUpperRight);

  const float px = p.x();
  const float py = p.y();

  // Radii of the corner whose region contains `p`, and the offset of `p` from
  // the center of that corner's ellipse.
  gfx::Vector2dF radii;
  float dx = 0.0f;
  float dy = 0.0f;

  // The corner regions are probed in a fixed order: upper-left, lower-left,
  // upper-right, lower-right.
  if (px < bounds.x() + ul_radii.x() && py < bounds.y() + ul_radii.y()) {
    radii = ul_radii;
    dx = px - (bounds.x() + ul_radii.x());
    dy = py - (bounds.y() + ul_radii.y());
    CHECK_LT(dx, 0);
    CHECK_LT(dy, 0);
  } else if (px < bounds.x() + ll_radii.x() &&
             py > bounds.bottom() - ll_radii.y()) {
    radii = ll_radii;
    dx = px - (bounds.x() + ll_radii.x());
    dy = py - (bounds.bottom() - ll_radii.y());
    CHECK_LT(dx, 0);
    CHECK_GT(dy, 0);
  } else if (px > bounds.right() - ur_radii.x() &&
             py < bounds.y() + ur_radii.y()) {
    radii = ur_radii;
    dx = px - (bounds.right() - ur_radii.x());
    dy = py - (bounds.y() + ur_radii.y());
    CHECK_GT(dx, 0);
    CHECK_LT(dy, 0);
  } else if (px > bounds.right() - lr_radii.x() &&
             py > bounds.bottom() - lr_radii.y()) {
    radii = lr_radii;
    dx = px - (bounds.right() - lr_radii.x());
    dy = py - (bounds.bottom() - lr_radii.y());
    CHECK_GT(dx, 0);
    CHECK_GT(dy, 0);
  } else {
    // Not in any corner region, so the rounding cannot exclude the point.
    return true;
  }

  // A point (x, y) is inside an axis-aligned ellipse with semi-axes a and b
  // centered at the origin iff x^2/a^2 + y^2/b^2 <= 1, which is evaluated
  // here without divisions as b^2*x^2 + a^2*y^2 <= (a*b)^2.
  const float lhs =
      Square(dx) * Square(radii.y()) + Square(dy) * Square(radii.x());
  const float rhs = Square(radii.x() * radii.y());
  return lhs <= rhs;
}
// Returns the vertex of `rectf` that corresponds to `corner`.
gfx::PointF GetCornerCoordinates(const gfx::RectF& rectf, Corner corner) {
  gfx::PointF vertex;
  switch (corner) {
    case Corner::kLowerLeft:
      vertex = rectf.bottom_left();
      break;
    case Corner::kLowerRight:
      vertex = rectf.bottom_right();
      break;
    case Corner::kUpperRight:
      vertex = rectf.top_right();
      break;
    case Corner::kUpperLeft:
      vertex = rectf.origin();
      break;
  }
  return vertex;
}
// Decides whether the radius of `corner` on `rect` needs to be overridden so
// that the corner is not drawn outside the matching rounded corner of
// `containing_rect`.
//
//   * If `containing_rect` has a square `corner`, nothing needs overriding.
//   * If `rect` has a square `corner`, it is overridden when that vertex falls
//     outside the curvature of `containing_rect`.
//   * Otherwise, with `consider_curvature` the corner is overridden when the
//     two corner arcs intersect or the midpoint of the `rect` arc falls
//     outside `containing_rect`; without it, only when the two vertices
//     (nearly) coincide.
bool ShouldOverrideCornerRadius(const gfx::RRectF& rect,
                                const gfx::RRectF& containing_rect,
                                Corner corner,
                                bool consider_curvature) {
  // This also covers the case where both corners are square.
  if (containing_rect.GetCornerRadii(corner).IsZero()) {
    return false;
  }

  const gfx::PointF rect_vertex = GetCornerCoordinates(rect.rect(), corner);
  if (rect.GetCornerRadii(corner).IsZero()) {
    return !CheckCornerContainment(rect_vertex, containing_rect);
  }

  if (consider_curvature) {
    const CircularArc rect_arc = GetArcForCorner(rect, corner);
    const CircularArc containing_arc =
        GetArcForCorner(containing_rect, corner);
    if (rect_arc.Intersects(containing_arc)) {
      return true;
    }
    return !CheckCornerContainment(rect_arc.GetMidPointOnArc(),
                                   containing_rect);
  }

  const gfx::PointF containing_vertex =
      GetCornerCoordinates(containing_rect.rect(), corner);
  return rect_vertex.IsWithinDistance(containing_vertex, 0.01);
}
// Set of rounded-rect corners.
using Corners = base::flat_set<gfx::RRectF::Corner>;

// Returns the corners of `rect` whose radius should be overridden, as decided
// by ShouldOverrideCornerRadius(). The result is empty if `containing_rect`
// has no rounded corners or if its bounding rect does not contain the bounding
// rect of `rect`.
Corners FindCornersToOverrideRadius(const gfx::RRectF& rect,
                                    const gfx::RRectF& containing_rect,
                                    bool consider_curvature) {
  static constexpr std::array<Corner, 4> kAllCorners{
      Corner::kUpperLeft, Corner::kUpperRight, Corner::kLowerRight,
      Corner::kLowerLeft};

  Corners result;
  const bool may_need_override =
      containing_rect.HasRoundedCorners() &&
      containing_rect.rect().Contains(rect.rect());
  if (may_need_override) {
    result.reserve(kAllCorners.size());
    for (const Corner candidate : kAllCorners) {
      if (ShouldOverrideCornerRadius(rect, containing_rect, candidate,
                                     consider_curvature)) {
        result.insert(candidate);
      }
    }
  }
  return result;
}
// Returns `bounds` after applying `transform` to it. The work is delegated to
// gfx::MaskFilterInfo, which is seeded with `bounds`, transformed, and then
// queried for its rounded corner bounds.
gfx::RRectF ApplyTransform(const gfx::RRectF& bounds,
                           const gfx::Transform& transform) {
  gfx::MaskFilterInfo transformed_info(bounds);
  transformed_info.ApplyTransform(transform);
  return transformed_info.rounded_corner_bounds();
}
// Returns `transform` pre-concatenated first with a translation to the origin
// of `layer`'s bounds and then with `layer`'s own target transform. The
// caller passes the transform accumulated for the parent of `layer`, so the
// result is the accumulated transform for `layer` itself.
gfx::Transform AccumulateTargetTransform(const ui::Layer* layer,
                                         const gfx::Transform& transform) {
  gfx::Transform result(transform);

  gfx::Transform origin_offset;
  origin_offset.Translate(layer->bounds().x(), layer->bounds().y());
  result.PreConcat(origin_offset);

  // Skip the concatenation when the layer has no transform of its own.
  const gfx::Transform& own_transform = layer->GetTargetTransform();
  if (own_transform.IsIdentity()) {
    return result;
  }
  result.PreConcat(own_transform);
  return result;
}
}
// Stores `restore_tree`; when it is true, the synchronization pass records the
// original rounded-corner state of every layer it modifies so that
// RestoreLayerTree() can put it back.
LayerTreeSynchronizerBase::LayerTreeSynchronizerBase(bool restore_tree)
: restore_tree_(restore_tree) {
// Pre-size the record of modified layers to avoid early regrowth.
original_layers_info_.reserve(kExpectedModifiedLayers);
}
LayerTreeSynchronizerBase::~LayerTreeSynchronizerBase() = default;
// Synchronizes the rounded corners of the layer tree rooted at `layer` against
// `reference_bounds`. `root_layer` must be non-null and contain `layer`; the
// target transform of `layer` relative to `root_layer` is used as the starting
// transform for the traversal. Returns true if any layer was altered, and
// false without doing anything if `reference_bounds` is empty.
bool LayerTreeSynchronizerBase::SynchronizeLayerTreeRoundedCorners(
    ui::Layer* layer,
    const ui::Layer* root_layer,
    const gfx::RRectF& reference_bounds,
    bool consider_curvature) {
  CHECK(root_layer);
  CHECK(root_layer->Contains(layer));

  if (!reference_bounds.IsEmpty()) {
    gfx::Transform relative_transform;
    layer->GetTargetTransformRelativeTo(root_layer, &relative_transform);
    return SynchronizeLayerTreeRoundedCornersImpl(
        layer, reference_bounds, relative_transform, consider_curvature);
  }
  return false;
}
// Recursive worker for SynchronizeLayerTreeRoundedCorners(). `transform` is
// the transform accumulated for `layer`; it is used to express the layer's
// rounded rect in the coordinate space of `reference_bounds`. Updates `layer`
// first and then visits every child. Returns true if `layer` or any
// descendant was altered.
bool LayerTreeSynchronizerBase::SynchronizeLayerTreeRoundedCornersImpl(
    ui::Layer* layer,
    const gfx::RRectF& reference_bounds,
    const gfx::Transform& transform,
    bool consider_curvature) {
  CHECK(layer);

  // Updates the radii of `layer` itself; returns true if they changed.
  const auto update_layer = [&]() -> bool {
    // Layers without rounded corners, invisible layers and layers whose
    // transform does not preserve 2d axis alignment are left untouched.
    if (layer->rounded_corner_radii().IsEmpty() || !layer->visible() ||
        !transform.Preserves2dAxisAlignment()) {
      return false;
    }

    const gfx::RRectF local_rrect(gfx::RectF(layer->bounds().size()),
                                  layer->rounded_corner_radii());
    const gfx::RRectF rrect_in_root = ApplyTransform(local_rrect, transform);
    const Corners overridden = FindCornersToOverrideRadius(
        rrect_in_root, reference_bounds, consider_curvature);
    if (overridden.empty()) {
      return false;
    }

    // Bring the reference bounds into the layer's own space so that the radii
    // taken from it can be applied to the layer directly.
    const gfx::RRectF reference_in_local =
        ApplyTransform(reference_bounds, transform.GetCheckedInverse());

    const gfx::RoundedCornersF current_radii = layer->rounded_corner_radii();
    const auto pick_radius = [&](Corner corner, float current_radius) {
      return overridden.contains(corner)
                 ? reference_in_local.GetCornerRadii(corner).x()
                 : current_radius;
    };

    gfx::RoundedCornersF updated_radii = current_radii;
    updated_radii.Set(
        pick_radius(Corner::kUpperLeft, current_radii.upper_left()),
        pick_radius(Corner::kUpperRight, current_radii.upper_right()),
        pick_radius(Corner::kLowerRight, current_radii.lower_right()),
        pick_radius(Corner::kLowerLeft, current_radii.lower_left()));

    if (updated_radii != current_radii) {
      // Remember the original state the first time this layer is modified.
      if (restore_tree_ && !original_layers_info_.contains(layer)) {
        original_layers_info_.insert({layer,
                                      {layer->rounded_corner_radii(),
                                       layer->is_fast_rounded_corner()}});
      }
      layer->SetRoundedCornerRadius(updated_radii);
      layer->SetIsFastRoundedCorner(!updated_radii.IsEmpty());
      return true;
    }
    return false;
  };

  const bool layer_altered = update_layer();

  bool subtree_altered = false;
  for (ui::Layer* child : layer->children()) {
    const bool child_altered = SynchronizeLayerTreeRoundedCornersImpl(
        child, reference_bounds, AccumulateTargetTransform(child, transform),
        consider_curvature);
    if (child_altered) {
      subtree_altered = true;
    }
  }
  return layer_altered || subtree_altered;
}
// Forgets the recorded original state of all modified layers.
void LayerTreeSynchronizerBase::ResetCachedLayerInfo() {
original_layers_info_.clear();
}
// Restores the recorded rounded-corner state of every layer in the tree rooted
// at `layer`. Does nothing when no original state has been recorded.
void LayerTreeSynchronizerBase::RestoreLayerTree(ui::Layer* layer) {
  const bool has_recorded_state = !original_layers_info_.empty();
  if (has_recorded_state) {
    RestoreLayerTreeImpl(layer);
  }
}
// Recursive worker for RestoreLayerTree(): if original state was recorded for
// `layer`, reapplies its rounded corner radii and fast-rounded-corner flag,
// then visits every child.
void LayerTreeSynchronizerBase::RestoreLayerTreeImpl(ui::Layer* layer) {
  const auto recorded = original_layers_info_.find(layer);
  if (recorded != original_layers_info_.end()) {
    const auto& original_state = recorded->second;
    layer->SetRoundedCornerRadius(original_state.first);
    layer->SetIsFastRoundedCorner(original_state.second);
  }

  for (ui::Layer* descendant : layer->children()) {
    RestoreLayerTreeImpl(descendant);
  }
}
// Forwards `restore_tree` to LayerTreeSynchronizerBase.
LayerTreeSynchronizer::LayerTreeSynchronizer(bool restore_tree)
: LayerTreeSynchronizerBase(restore_tree) {}
LayerTreeSynchronizer::~LayerTreeSynchronizer() = default;
// Synchronizes the rounded corners of the layer tree rooted at `layer` with
// `reference_bounds`, always taking corner curvature into account. If anything
// was altered and `layer` is not yet being observed, starts observing it.
void LayerTreeSynchronizer::SynchronizeRoundedCorners(
    ui::Layer* layer,
    const ui::Layer* root_layer,
    const gfx::RRectF& reference_bounds) {
  if (!SynchronizeLayerTreeRoundedCorners(layer, root_layer, reference_bounds,
                                          /*consider_curvature=*/true)) {
    return;
  }
  if (altered_layer_observation_.IsObservingSource(layer)) {
    return;
  }
  altered_layer_observation_.Observe(layer);
}
// Restores the layer tree rooted at the observed layer (if there is one), then
// drops the recorded layer state and stops observing.
void LayerTreeSynchronizer::Restore() {
  auto* const altered_layer = altered_layer_observation_.GetSource();
  if (altered_layer) {
    RestoreLayerTree(altered_layer);
  }
  ResetCachedLayerInfo();
  altered_layer_observation_.Reset();
}
// Layer-destruction notification: stops the current observation. `layer` is
// not otherwise used.
void LayerTreeSynchronizer::LayerDestroyed(ui::Layer* layer) {
altered_layer_observation_.Reset();
}
// Forwards `restore_tree` to LayerTreeSynchronizerBase.
WindowTreeSynchronizer::WindowTreeSynchronizer(bool restore_tree)
: LayerTreeSynchronizerBase(restore_tree) {}
WindowTreeSynchronizer::~WindowTreeSynchronizer() = default;
// Synchronizes the rounded corners of the layer tree of every window yielded
// by GetTransientTreeIterator(window, ignore_predicate) with
// `reference_bounds`. `root_window` must contain `window`. Each window whose
// layer tree was altered is added to the set of observed windows, unless it is
// already observed.
void WindowTreeSynchronizer::SynchronizeRoundedCorners(
    aura::Window* window,
    const aura::Window* root_window,
    const gfx::RRectF& reference_bounds,
    bool consider_curvature,
    TransientTreeIgnorePredicate ignore_predicate) {
  CHECK(root_window->Contains(window));

  for (auto* transient_window :
       GetTransientTreeIterator(window, ignore_predicate)) {
    if (!SynchronizeLayerTreeRoundedCorners(
            transient_window->layer(), root_window->layer(), reference_bounds,
            consider_curvature)) {
      continue;
    }
    if (altered_window_observations_.IsObservingSource(transient_window)) {
      continue;
    }
    altered_window_observations_.AddObservation(transient_window);
  }
}
void WindowTreeSynchronizer::Restore() {
for (aura::Window* window : altered_window_observations_.sources()) {
RestoreLayerTree(window->layer());
}
ResetCachedLayerInfo();
altered_window_observations_.RemoveAllObservations();
}
// Window-destruction notification: stops observing `window`.
void WindowTreeSynchronizer::OnWindowDestroying(aura::Window* window) {
altered_window_observations_.RemoveObservation(window);
}
}