#ifndef EXTENSIONS_BROWSER_API_WEB_REQUEST_WEB_REQUEST_API_H_
#define EXTENSIONS_BROWSER_API_WEB_REQUEST_WEB_REQUEST_API_H_
#include <stdint.h>
#include <list>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "base/containers/unique_ptr_adapters.h"
#include "base/feature_list.h"
#include "base/gtest_prod_util.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "base/no_destructor.h"
#include "base/strings/string_util.h"
#include "base/time/time.h"
#include "base/values.h"
#include "content/public/browser/content_browser_client.h"
#include "content/public/browser/global_request_id.h"
#include "extensions/browser/api/declarative/rules_registry.h"
#include "extensions/browser/api/declarative_webrequest/request_stage.h"
#include "extensions/browser/api/web_request/web_request_api_helpers.h"
#include "extensions/browser/api/web_request/web_request_permissions.h"
#include "extensions/browser/browser_context_keyed_api_factory.h"
#include "extensions/browser/event_router.h"
#include "extensions/browser/extension_api_frame_id_map.h"
#include "extensions/browser/extension_function.h"
#include "extensions/browser/extension_registry_observer.h"
#include "extensions/common/url_pattern_set.h"
#include "ipc/ipc_sender.h"
#include "net/base/auth.h"
#include "net/base/completion_once_callback.h"
#include "net/http/http_request_headers.h"
#include "services/metrics/public/cpp/ukm_source_id.h"
#include "services/network/public/mojom/url_loader_factory.mojom.h"
#include "services/network/public/mojom/websocket.mojom.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
// Forward declarations of global-namespace types referenced by this header.
// ExtensionWebRequestTimeTracker is only held through a std::unique_ptr member.
class ExtensionWebRequestTimeTracker;
// NOTE(review): ExtensionWebRequestEventRouter::EventResponse stores a GURL by
// value, which requires the complete type; presumably url/gurl.h is pulled in
// transitively by one of the includes above - confirm.
class GURL;
// Forward declarations of the content:: types that this header refers to only
// by pointer or reference.
namespace content {
class BrowserContext;
class RenderFrameHost;
// Taken by reference in WebRequestAPI::ProxyWebTransport(); declared here so
// the header does not depend on another include providing the name.
class RenderProcessHost;
}  // namespace content
// Forward declarations of the net:: types named in the signatures below.
// NOTE(review): some of these (e.g. HttpRequestHeaders) also have their header
// included at the top of this file, so the forward declaration is redundant
// but harmless.
namespace net {
class AuthChallengeInfo;
class AuthCredentials;
class HttpRequestHeaders;
class HttpResponseHeaders;
class SiteForCookies;
}
namespace extensions {
// Opaque declaration of the resource-type enum; the explicit underlying type
// (uint8_t) is what allows declaring it here without its enumerators.
enum class WebRequestResourceType : uint8_t;
// Types referenced by the declarations below; only forward-declared in this
// header.
class WebRequestEventDetails;
struct WebRequestInfo;
class WebRequestRulesRegistry;
// Browser-context-keyed service for the extensions webRequest API.
//
// In addition to the BrowserContextKeyedAPI hooks, this class observes the
// EventRouter (listener removal) and the ExtensionRegistry (extension
// load/unload), and exposes the MaybeProxy*/Proxy* entry points declared below.
//
// NOTE(review): only declarations are visible in this header; the behaviour of
// each method is defined in the implementation file and should be confirmed
// there.
class WebRequestAPI : public BrowserContextKeyedAPI,
                      public EventRouter::Observer,
                      public ExtensionRegistryObserver {
 public:
  // Completion callback for an auth challenge. Receives optional credentials
  // and a flag telling whether the request should be cancelled.
  using AuthRequestCallback = base::OnceCallback<void(
      const absl::optional<net::AuthCredentials>& credentials,
      bool should_cancel)>;

  // Polymorphic base class for the objects owned by a ProxySet.
  class Proxy {
   public:
    virtual ~Proxy() = default;

    // Asks this proxy to handle an auth challenge for |request_id|, reporting
    // the outcome through |callback|. The method is virtual but not pure, so
    // a base implementation is defined out of line.
    virtual void HandleAuthRequest(
        const net::AuthChallengeInfo& auth_info,
        scoped_refptr<net::HttpResponseHeaders> response_headers,
        int32_t request_id,
        AuthRequestCallback callback);
  };

  // Owns a set of Proxy objects and keeps a two-way association between
  // proxies and content::GlobalRequestID values. Not copyable.
  class ProxySet {
   public:
    ProxySet();
    ProxySet(const ProxySet&) = delete;
    ProxySet& operator=(const ProxySet&) = delete;
    ~ProxySet();

    // Takes ownership of |proxy|.
    void AddProxy(std::unique_ptr<Proxy> proxy);
    void RemoveProxy(Proxy* proxy);

    // Adds / removes the association between |proxy| and request |id|.
    void AssociateProxyWithRequestId(Proxy* proxy,
                                     const content::GlobalRequestID& id);
    void DisassociateProxyWithRequestId(Proxy* proxy,
                                        const content::GlobalRequestID& id);

    // Looks up the proxy associated with |id|.
    // NOTE(review): presumably returns nullptr when there is none - confirm.
    Proxy* GetProxyFromRequestId(const content::GlobalRequestID& id);

    // Set-level counterpart of Proxy::HandleAuthRequest(), keyed by a
    // content::GlobalRequestID instead of a bare int32_t request id.
    void MaybeProxyAuthRequest(
        const net::AuthChallengeInfo& auth_info,
        scoped_refptr<net::HttpResponseHeaders> response_headers,
        const content::GlobalRequestID& request_id,
        AuthRequestCallback callback);

   private:
    // Owning storage for the proxies; the comparator allows lookups by raw
    // pointer.
    std::set<std::unique_ptr<Proxy>, base::UniquePtrComparator> proxies_;
    // Forward and reverse indices between request ids and (non-owned) proxies.
    std::map<content::GlobalRequestID, Proxy*> request_id_to_proxy_map_;
    std::map<Proxy*, std::set<content::GlobalRequestID>>
        proxy_to_request_id_map_;
  };

  // Hands out request ids and remembers ids saved for a
  // (routing_id, network_service_request_id) pair. Not copyable.
  class RequestIDGenerator {
   public:
    RequestIDGenerator();
    RequestIDGenerator(const RequestIDGenerator&) = delete;
    RequestIDGenerator& operator=(const RequestIDGenerator&) = delete;
    ~RequestIDGenerator();

    // NOTE(review): Generate() returns int64_t while SaveID() and the saved
    // map use uint64_t; the signedness mismatch is kept as-is because callers
    // depend on these signatures.
    int64_t Generate(int32_t routing_id, int32_t network_service_request_id);
    void SaveID(int32_t routing_id,
                int32_t network_service_request_id,
                uint64_t request_id);

   private:
    int64_t id_ = 0;
    // Keyed by (routing_id, network_service_request_id).
    std::map<std::pair<int32_t, int32_t>, uint64_t> saved_id_map_;
  };

  explicit WebRequestAPI(content::BrowserContext* context);
  WebRequestAPI(const WebRequestAPI&) = delete;
  WebRequestAPI& operator=(const WebRequestAPI&) = delete;
  ~WebRequestAPI() override;

  // BrowserContextKeyedAPI support.
  static BrowserContextKeyedAPIFactory<WebRequestAPI>* GetFactoryInstance();
  void Shutdown() override;

  // EventRouter::Observer override.
  void OnListenerRemoved(const EventListenerInfo& details) override;

  // Gives the API the chance to interpose on a URLLoaderFactory binding.
  // |factory_receiver| and |header_client| are in/out pointers.
  // NOTE(review): presumably returns true when a proxy was installed - confirm.
  bool MaybeProxyURLLoaderFactory(
      content::BrowserContext* browser_context,
      content::RenderFrameHost* frame,
      int render_process_id,
      content::ContentBrowserClient::URLLoaderFactoryType type,
      absl::optional<int64_t> navigation_id,
      ukm::SourceIdObj ukm_source_id,
      mojo::PendingReceiver<network::mojom::URLLoaderFactory>* factory_receiver,
      mojo::PendingRemote<network::mojom::TrustedURLLoaderHeaderClient>*
          header_client,
      const url::Origin& request_initiator = url::Origin());

  // Gives the API the chance to handle an auth challenge for |request_id|.
  // NOTE(review): presumably returns true when |callback| will be invoked by
  // the API - confirm.
  bool MaybeProxyAuthRequest(
      content::BrowserContext* browser_context,
      const net::AuthChallengeInfo& auth_info,
      scoped_refptr<net::HttpResponseHeaders> response_headers,
      const content::GlobalRequestID& request_id,
      bool is_main_frame,
      AuthRequestCallback callback);

  // Routes a WebSocket handshake for |frame| through the API.
  void ProxyWebSocket(
      content::RenderFrameHost* frame,
      content::ContentBrowserClient::WebSocketFactory factory,
      const GURL& url,
      const net::SiteForCookies& site_for_cookies,
      const absl::optional<std::string>& user_agent,
      mojo::PendingRemote<network::mojom::WebSocketHandshakeClient>
          handshake_client);

  // Routes a WebTransport handshake through the API; the result is delivered
  // through |callback|.
  void ProxyWebTransport(
      content::RenderProcessHost& render_process_host,
      int frame_routing_id,
      const GURL& url,
      const url::Origin& initiator_origin,
      mojo::PendingRemote<network::mojom::WebTransportHandshakeClient>
          handshake_client,
      content::ContentBrowserClient::WillCreateWebTransportCallback callback);

  void ForceProxyForTesting();

  // Queries about whether proxying may be in effect.
  bool MayHaveProxies() const;
  bool MayHaveWebsocketProxiesForExtensionTelemetry() const;
  bool HasExtraHeadersListenerForTesting();

 private:
  friend class BrowserContextKeyedAPIFactory<WebRequestAPI>;

  // Service traits.
  // NOTE(review): presumably read by BrowserContextKeyedAPIFactory (hence the
  // friend declaration above) - confirm.
  static const char* service_name() { return "WebRequestAPI"; }
  static const bool kServiceRedirectedInIncognito = true;
  static const bool kServiceIsNULLWhileTesting = true;

  void UpdateMayHaveProxies();

  // ExtensionRegistryObserver overrides.
  void OnExtensionLoaded(content::BrowserContext* browser_context,
                         const Extension* extension) override;
  void OnExtensionUnloaded(content::BrowserContext* browser_context,
                           const Extension* extension,
                           UnloadedExtensionReason reason) override;

  int web_request_extension_count_ = 0;
  const raw_ptr<content::BrowserContext, DanglingUntriaged> browser_context_;
  RequestIDGenerator request_id_generator_;
  std::unique_ptr<ProxySet> proxies_;
  // Default-initialized so the flag never holds an indeterminate value, even
  // if a constructor omits it from its initializer list.
  bool may_have_proxies_ = false;
};
class ExtensionWebRequestEventRouter {
public:
struct BlockedRequest;
using BrowserContextID = std::uintptr_t;
static BrowserContextID GetBrowserContextID(
content::BrowserContext* browser_context) {
return reinterpret_cast<BrowserContextID>(
static_cast<void*>(browser_context));
}
enum EventTypes {
kInvalidEvent = 0,
kOnBeforeRequest = 1 << 0,
kOnBeforeSendHeaders = 1 << 1,
kOnSendHeaders = 1 << 2,
kOnHeadersReceived = 1 << 3,
kOnBeforeRedirect = 1 << 4,
kOnAuthRequired = 1 << 5,
kOnResponseStarted = 1 << 6,
kOnErrorOccurred = 1 << 7,
kOnCompleted = 1 << 8,
};
struct RequestFilter {
RequestFilter();
~RequestFilter();
RequestFilter(const RequestFilter&) = delete;
RequestFilter& operator=(const RequestFilter&) = delete;
RequestFilter(RequestFilter&& other);
RequestFilter& operator=(RequestFilter&& other);
bool InitFromValue(const base::Value::Dict& value, std::string* error);
extensions::URLPatternSet urls;
std::vector<WebRequestResourceType> types;
int tab_id;
int window_id;
};
struct EventResponse {
EventResponse(const std::string& extension_id,
const base::Time& extension_install_time);
EventResponse(const EventResponse&) = delete;
EventResponse& operator=(const EventResponse&) = delete;
~EventResponse();
std::string extension_id;
base::Time extension_install_time;
bool cancel;
GURL new_url;
std::unique_ptr<net::HttpRequestHeaders> request_headers;
std::unique_ptr<extension_web_request_api_helpers::ResponseHeaders>
response_headers;
absl::optional<net::AuthCredentials> auth_credentials;
};
enum class AuthRequiredResponse {
AUTH_REQUIRED_RESPONSE_NO_ACTION,
AUTH_REQUIRED_RESPONSE_SET_AUTH,
AUTH_REQUIRED_RESPONSE_CANCEL_AUTH,
AUTH_REQUIRED_RESPONSE_IO_PENDING,
};
using AuthCallback = base::OnceCallback<void(AuthRequiredResponse)>;
static ExtensionWebRequestEventRouter* GetInstance();
void RegisterRulesRegistry(
content::BrowserContext* browser_context,
int rules_registry_id,
scoped_refptr<extensions::WebRequestRulesRegistry> rules_registry);
int OnBeforeRequest(content::BrowserContext* browser_context,
WebRequestInfo* request,
net::CompletionOnceCallback callback,
GURL* new_url,
bool* should_collapse_initiator);
using BeforeSendHeadersCallback =
base::OnceCallback<void(const std::set<std::string>& removed_headers,
const std::set<std::string>& set_headers,
int error_code)>;
int OnBeforeSendHeaders(content::BrowserContext* browser_context,
const WebRequestInfo* request,
BeforeSendHeadersCallback callback,
net::HttpRequestHeaders* headers);
void OnSendHeaders(content::BrowserContext* browser_context,
const WebRequestInfo* request,
const net::HttpRequestHeaders& headers);
int OnHeadersReceived(
content::BrowserContext* browser_context,
const WebRequestInfo* request,
net::CompletionOnceCallback callback,
const net::HttpResponseHeaders* original_response_headers,
scoped_refptr<net::HttpResponseHeaders>* override_response_headers,
GURL* preserve_fragment_on_redirect_url);
AuthRequiredResponse OnAuthRequired(content::BrowserContext* browser_context,
const WebRequestInfo* request,
const net::AuthChallengeInfo& auth_info,
AuthCallback callback,
net::AuthCredentials* credentials);
void OnBeforeRedirect(content::BrowserContext* browser_context,
const WebRequestInfo* request,
const GURL& new_location);
void OnResponseStarted(content::BrowserContext* browser_context,
const WebRequestInfo* request,
int net_error);
void OnCompleted(content::BrowserContext* browser_context,
const WebRequestInfo* request,
int net_error);
void OnErrorOccurred(content::BrowserContext* browser_context,
const WebRequestInfo* request,
bool started,
int net_error);
void OnRequestWillBeDestroyed(content::BrowserContext* browser_context,
const WebRequestInfo* request);
void OnEventHandled(content::BrowserContext* browser_context,
const std::string& extension_id,
const std::string& event_name,
const std::string& sub_event_name,
uint64_t request_id,
int render_process_id,
int web_view_instance_id,
int worker_thread_id,
int64_t service_worker_version_id,
EventResponse* response);
bool AddEventListener(content::BrowserContext* browser_context,
const std::string& extension_id,
const std::string& extension_name,
events::HistogramValue histogram_value,
const std::string& event_name,
const std::string& sub_event_name,
RequestFilter filter,
int extra_info_spec,
int render_process_id,
int web_view_instance_id,
int worker_thread_id,
int64_t service_worker_version_id);
void RemoveWebViewEventListeners(content::BrowserContext* browser_context,
int render_process_id,
int web_view_instance_id);
void OnOTRBrowserContextCreated(
content::BrowserContext* original_browser_context,
content::BrowserContext* otr_browser_context);
void OnOTRBrowserContextDestroyed(
content::BrowserContext* original_browser_context,
content::BrowserContext* otr_browser_context);
void AddCallbackForPageLoad(base::OnceClosure callback);
bool HasExtraHeadersListenerForRequest(
content::BrowserContext* browser_context,
const WebRequestInfo* request);
bool HasAnyExtraHeadersListener(content::BrowserContext* browser_context);
void IncrementExtraHeadersListenerCount(
content::BrowserContext* browser_context);
void DecrementExtraHeadersListenerCount(
content::BrowserContext* browser_context);
void OnBrowserContextShutdown(content::BrowserContext* browser_context);
size_t GetListenerCountForTesting(content::BrowserContext* browser_context,
const std::string& event_name);
size_t GetInactiveListenerCountForTesting(
content::BrowserContext* browser_context,
const std::string& event_name);
private:
friend class WebRequestAPI;
friend class base::NoDestructor<ExtensionWebRequestEventRouter>;
FRIEND_TEST_ALL_PREFIXES(ExtensionWebRequestTest, AddAndRemoveListeners);
FRIEND_TEST_ALL_PREFIXES(ExtensionWebRequestTest, BrowserContextShutdown);
struct EventListener {
struct ID {
ID(content::BrowserContext* browser_context,
const std::string& extension_id,
const std::string& sub_event_name,
int render_process_id,
int web_view_instance_id,
int worker_thread_id,
int64_t service_worker_version_id);
ID(const ID& source);
bool operator==(const ID& that) const;
raw_ptr<content::BrowserContext> browser_context;
std::string extension_id;
std::string sub_event_name;
int render_process_id;
int web_view_instance_id;
int worker_thread_id;
int64_t service_worker_version_id;
};
EventListener(ID id);
EventListener(const EventListener&) = delete;
EventListener& operator=(const EventListener&) = delete;
~EventListener();
ID id;
std::string extension_name;
events::HistogramValue histogram_value = events::UNKNOWN;
RequestFilter filter;
int extra_info_spec = 0;
std::unordered_set<uint64_t> blocked_requests;
};
using RawListeners = std::vector<EventListener*>;
using ListenerIDs = std::vector<EventListener::ID>;
using Listeners = std::vector<std::unique_ptr<EventListener>>;
using ListenerMap = std::map<std::string, Listeners>;
struct BrowserContextData {
BrowserContextData();
BrowserContextData(BrowserContextData&&);
~BrowserContextData();
ListenerMap active_listeners;
ListenerMap inactive_listeners;
int extra_headers_count = 0;
raw_ptr<content::BrowserContext> cross_context = nullptr;
};
using DataMap = std::map<BrowserContextID, BrowserContextData>;
using BlockedRequestMap = std::map<uint64_t, BlockedRequest>;
using SignaledRequestMap = std::map<uint64_t, int>;
using CallbacksForPageLoad = std::list<base::OnceClosure>;
enum class ListenerUpdateType {
kRemove,
kDeactivate,
};
ExtensionWebRequestEventRouter();
ExtensionWebRequestEventRouter(const ExtensionWebRequestEventRouter&) =
delete;
ExtensionWebRequestEventRouter& operator=(
const ExtensionWebRequestEventRouter&) = delete;
~ExtensionWebRequestEventRouter() = delete;
EventListener* FindEventListener(const EventListener::ID& id);
EventListener* FindEventListenerInContainer(const EventListener::ID& id,
Listeners& listeners);
void UpdateActiveListener(ListenerUpdateType update_type,
BrowserContextID browser_context_id,
const ExtensionId& extension_id,
const std::string& sub_event_name,
int worker_thread_id,
int64_t service_worker_version_id);
void RemoveLazyListener(content::BrowserContext* original_context,
const ExtensionId& extension_id,
const std::string& sub_event_name);
static std::unique_ptr<EventListener> RemoveMatchingListener(
Listeners& listeners,
const ExtensionId& extension_id,
const std::string& sub_event_name,
absl::optional<int> worker_thread_id,
absl::optional<int64_t> service_worker_version_id,
BrowserContextID browser_context_id);
void CleanUpForListener(const EventListener& listener,
ListenerUpdateType removal_type);
void ClearPendingCallbacks(const WebRequestInfo& request);
bool DispatchEvent(content::BrowserContext* browser_context,
const WebRequestInfo* request,
const RawListeners& listener_ids,
std::unique_ptr<WebRequestEventDetails> event_details);
void DispatchEventToListeners(
content::BrowserContext* browser_context,
std::unique_ptr<ListenerIDs> listener_ids,
uint64_t request_id,
std::unique_ptr<WebRequestEventDetails> event_details);
RawListeners GetMatchingListeners(content::BrowserContext* browser_context,
const std::string& event_name,
const WebRequestInfo* request,
int* extra_info_spec);
static bool ListenerMatchesRequest(const EventListener& listener,
const WebRequestInfo& request,
content::BrowserContext& browser_context,
bool is_request_from_extension,
bool crosses_incognito);
static void GetMatchingListenersForRequest(
const Listeners& listeners,
const WebRequestInfo& request,
content::BrowserContext& browser_context,
bool is_request_from_extension,
bool crosses_incognito,
RawListeners* listeners_out,
int* extra_info_spec_out);
void DecrementBlockCount(content::BrowserContext* browser_context,
const std::string& extension_id,
const std::string& event_name,
uint64_t request_id,
EventResponse* response,
int extra_info_spec);
int ExecuteDeltas(content::BrowserContext* browser_context,
const WebRequestInfo* request,
bool call_callback);
bool ProcessDeclarativeRules(
content::BrowserContext* browser_context,
const std::string& event_name,
const WebRequestInfo* request,
extensions::RequestStage request_stage,
const net::HttpResponseHeaders* filtered_response_headers);
void SendMessages(content::BrowserContext* browser_context,
const BlockedRequest& blocked_request);
void OnRulesRegistryReady(content::BrowserContext* browser_context,
const std::string& event_name,
uint64_t request_id,
extensions::RequestStage request_stage);
bool GetAndSetSignaled(uint64_t request_id, EventTypes event_type);
void ClearSignaled(uint64_t request_id, EventTypes event_type);
bool IsPageLoad(const WebRequestInfo& request) const;
void NotifyPageLoad();
content::BrowserContext* GetCrossBrowserContext(
content::BrowserContext* browser_context) const;
bool WasSignaled(const WebRequestInfo& request) const;
bool HasAnyExtraHeadersListenerImpl(content::BrowserContext* browser_context);
DataMap data_;
BlockedRequestMap blocked_requests_;
SignaledRequestMap signaled_requests_;
std::unique_ptr<ExtensionWebRequestTimeTracker> request_time_tracker_;
CallbacksForPageLoad callbacks_for_page_load_;
typedef std::pair<BrowserContextID, int> RulesRegistryKey;
std::map<RulesRegistryKey,
scoped_refptr<extensions::WebRequestRulesRegistry> > rules_registries_;
};
// Common base class of the webRequest extension function classes declared in
// this header.
class WebRequestInternalFunction : public ExtensionFunction {
 public:
  WebRequestInternalFunction() = default;

 protected:
  ~WebRequestInternalFunction() override = default;

  // Returns extension_id() when extension() is non-null, and
  // base::EmptyString() otherwise, so extension_id() is never called without
  // an extension.
  const std::string& extension_id_safe() const {
    if (!extension()) {
      return base::EmptyString();
    }
    return extension_id();
  }
};
// Extension function declared under the name
// "webRequestInternal.addEventListener".
class WebRequestInternalAddEventListenerFunction
    : public WebRequestInternalFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("webRequestInternal.addEventListener",
                             WEBREQUESTINTERNAL_ADDEVENTLISTENER)

 protected:
  ~WebRequestInternalAddEventListenerFunction() override = default;

  // ExtensionFunction override; defined out of line.
  ResponseAction Run() override;
};
// Extension function declared under the name
// "webRequestInternal.eventHandled".
class WebRequestInternalEventHandledFunction
    : public WebRequestInternalFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("webRequestInternal.eventHandled",
                             WEBREQUESTINTERNAL_EVENTHANDLED)

 protected:
  ~WebRequestInternalEventHandledFunction() override = default;

 private:
  // Error-path helper that takes ownership of |response|; the remaining
  // arguments identify the event, the request, and the originating listener.
  // NOTE(review): exact behaviour is defined out of line - confirm.
  void OnError(
      const std::string& event_name,
      const std::string& sub_event_name,
      uint64_t request_id,
      int render_process_id,
      int web_view_instance_id,
      std::unique_ptr<ExtensionWebRequestEventRouter::EventResponse> response);

  // ExtensionFunction override; defined out of line.
  ResponseAction Run() override;
};
// Extension function declared under the name
// "webRequest.handlerBehaviorChanged". It overrides the quota hooks in
// addition to Run().
class WebRequestHandlerBehaviorChangedFunction
    : public WebRequestInternalFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("webRequest.handlerBehaviorChanged",
                             WEBREQUEST_HANDLERBEHAVIORCHANGED)

 protected:
  ~WebRequestHandlerBehaviorChangedFunction() override = default;

  // Quota hooks: supply the heuristics for this function and react when the
  // quota is exceeded. Both are defined out of line.
  void GetQuotaLimitHeuristics(
      extensions::QuotaLimitHeuristics* heuristics) const override;
  void OnQuotaExceeded(std::string error) override;

  // ExtensionFunction override; defined out of line.
  ResponseAction Run() override;
};
}
#endif