/**
 * Generic OAuth flow for MCP servers.
 *
 * Allows users to authenticate with any OAuth-compatible MCP server
 * by providing an authorization URL, a token URL, and client credentials.
 */
import type { OAuthCallbackFlowOptions } from "@oh-my-pi/pi-ai/utils/oauth/callback-server";
import { OAuthCallbackFlow } from "@oh-my-pi/pi-ai/utils/oauth/callback-server";
import type { OAuthController, OAuthCredentials } from "@oh-my-pi/pi-ai/utils/oauth/types";
/** Fallback callback port, used when no `callbackPort` is configured and the redirect URI is not an HTTP loopback URL. */
const DEFAULT_PORT = 3000;
/** Fallback callback path, used when neither `callbackPath` nor the redirect URI supplies a path. */
const CALLBACK_PATH = "/callback";
/**
 * Reports whether a hostname is one of the loopback names this module accepts.
 *
 * Only the exact strings `localhost` and `127.0.0.1` match; the comparison is
 * case-sensitive and no other loopback spellings are recognised.
 *
 * @param hostname - Hostname to classify (as found on `URL.hostname`).
 * @returns `true` for `localhost` or `127.0.0.1`, otherwise `false`.
 */
function isLoopbackHostname(hostname: string): boolean {
	switch (hostname) {
		case "localhost":
		case "127.0.0.1":
			return true;
		default:
			return false;
	}
}
/**
 * Validates a configured redirect URI and returns it unchanged.
 *
 * @param redirectUri - Raw configured value, possibly absent.
 * @returns The redirect URI exactly as configured, or `undefined` when the
 *   value is absent, empty, or whitespace-only.
 * @throws Error when the value has leading/trailing whitespace or uses a
 *   scheme other than http/https.
 * @throws TypeError (from the `URL` constructor) when the value is not a
 *   parseable absolute URL.
 */
function resolveRedirectUri(redirectUri: string | undefined): string | undefined {
	// Absent, empty and whitespace-only values all mean "not configured".
	if (redirectUri === undefined || redirectUri.trim() === "") {
		return undefined;
	}
	if (redirectUri.trim() !== redirectUri) {
		throw new Error("OAuth redirect URI must not include surrounding whitespace");
	}
	const { protocol } = new URL(redirectUri);
	const usesHttpScheme = protocol === "http:" || protocol === "https:";
	if (!usesHttpScheme) {
		throw new Error("OAuth redirect URI must use http or https");
	}
	return redirectUri;
}
/**
 * Parses a redirect URI into a `URL`.
 *
 * @param redirectUri - Redirect URI string, possibly absent.
 * @returns The parsed `URL`, or `undefined` when the input is `undefined` or empty.
 * @throws TypeError (from the `URL` constructor) when a non-empty input cannot be parsed.
 */
function parseRedirectUri(redirectUri: string | undefined): URL | undefined {
	if (!redirectUri) {
		return undefined;
	}
	return new URL(redirectUri);
}
/**
 * Returns the numeric port of a URL, filling in the scheme default when the
 * URL carries no explicit port.
 *
 * @param uri - Parsed URL.
 * @returns The explicit port if present; otherwise 443 for `https:` and 80 for
 *   every other protocol.
 */
function getUriPort(uri: URL): number {
	const explicitPort = uri.port;
	if (explicitPort === "") {
		// `URL.port` is empty when the port is omitted or equals the scheme default.
		return uri.protocol === "https:" ? 443 : 80;
	}
	return Number(explicitPort);
}
/**
 * Checks the extra requirements that apply to HTTPS loopback redirect URIs.
 *
 * Redirect URIs that are absent, not `https:`, or not on a loopback hostname
 * pass without further checks.
 *
 * @param config - OAuth configuration; only `callbackPort` is read.
 * @param redirectUri - Already-validated redirect URI, if any.
 * @throws Error when the redirect URI is an HTTPS loopback URL and
 *   `callbackPort` is missing or equal to the redirect URI's own port.
 */
function validateRedirectConfig(config: MCPOAuthConfig, redirectUri: string | undefined): void {
	const target = parseRedirectUri(redirectUri);
	if (target === undefined || target.protocol !== "https:") {
		return;
	}
	if (!isLoopbackHostname(target.hostname)) {
		return;
	}

	const { callbackPort } = config;
	if (callbackPort === undefined) {
		throw new Error(
			"HTTPS loopback redirect URIs require oauth.callbackPort to point at the local HTTP callback listener behind your TLS terminator",
		);
	}
	// The local listener and the TLS front end cannot share one port.
	if (callbackPort === getUriPort(target)) {
		throw new Error(
			"HTTPS loopback redirect URIs cannot reuse the same local port; terminate TLS separately and forward to oauth.callbackPort",
		);
	}
}
/**
 * Chooses the port for the local callback listener.
 *
 * Precedence: an explicit `callbackPort`, then the port of an HTTP loopback
 * redirect URI, then `DEFAULT_PORT`.
 *
 * @param callbackPort - Explicitly configured port, returned as-is when defined.
 * @param redirectUri - Already-validated redirect URI, if any.
 * @returns The port to use for the callback listener.
 */
function resolveCallbackPort(callbackPort: number | undefined, redirectUri: string | undefined): number {
	if (callbackPort !== undefined) {
		return callbackPort;
	}

	const target = parseRedirectUri(redirectUri);
	if (target === undefined || target.protocol !== "http:" || !isLoopbackHostname(target.hostname)) {
		return DEFAULT_PORT;
	}

	const derivedPort = getUriPort(target);
	if (Number.isFinite(derivedPort) && derivedPort > 0) {
		return derivedPort;
	}
	return DEFAULT_PORT;
}
/**
 * Chooses the path for the local callback listener.
 *
 * Precedence: a non-blank `callbackPath` (trimmed, with a leading `/` added if
 * missing), then the redirect URI's pathname, then `CALLBACK_PATH`.
 *
 * @param callbackPath - Explicitly configured path, if any.
 * @param redirectUri - Already-validated redirect URI, if any.
 * @returns The callback path to use.
 */
function resolveCallbackPath(callbackPath: string | undefined, redirectUri: string | undefined): string {
	const explicitPath = callbackPath?.trim() ?? "";
	if (explicitPath !== "") {
		return explicitPath.startsWith("/") ? explicitPath : `/${explicitPath}`;
	}

	const derivedPath = parseRedirectUri(redirectUri)?.pathname;
	if (derivedPath) {
		return derivedPath;
	}
	return CALLBACK_PATH;
}
/**
 * Extracts the callback hostname from a redirect URI, but only when it is a
 * loopback hostname.
 *
 * @param redirectUri - Already-validated redirect URI, if any.
 * @returns The redirect URI's hostname when it is a loopback name; otherwise
 *   `undefined` (including when no redirect URI is configured).
 */
function resolveCallbackHostname(redirectUri: string | undefined): string | undefined {
	const hostname = parseRedirectUri(redirectUri)?.hostname;
	if (hostname === undefined) {
		return undefined;
	}
	return isLoopbackHostname(hostname) ? hostname : undefined;
}
/**
 * Builds the callback-flow options from an MCP OAuth configuration.
 *
 * The redirect URI is validated first, so an invalid configuration throws
 * before any option is derived.
 *
 * @param config - OAuth configuration supplied by the caller.
 * @returns Preferred port, callback path, callback hostname and redirect URI.
 * @throws Error or TypeError when the redirect URI or its port configuration is invalid.
 */
function resolveCallbackOptions(config: MCPOAuthConfig): OAuthCallbackFlowOptions {
	const redirectUri = resolveRedirectUri(config.redirectUri);
	validateRedirectConfig(config, redirectUri);

	const preferredPort = resolveCallbackPort(config.callbackPort, redirectUri);
	const callbackPath = resolveCallbackPath(config.callbackPath, redirectUri);
	const callbackHostname = resolveCallbackHostname(redirectUri);

	return { preferredPort, callbackPath, callbackHostname, redirectUri };
}
/**
 * Configuration for authenticating against an MCP server with the OAuth
 * authorization-code flow.
 */
export interface MCPOAuthConfig {
/** Authorization endpoint URL. Existing `response_type`, `client_id` and `scope` query parameters are kept. */
authorizationUrl: string;
/** Token endpoint URL that receives the authorization-code exchange POST. */
tokenUrl: string;
/** Client id. When blank, a `client_id` embedded in `authorizationUrl` or dynamic client registration is used instead. */
clientId?: string;
/** Client secret sent with the token request; takes precedence over a secret obtained via dynamic registration. */
clientSecret?: string;
/** Value for the `scope` parameter, applied only when `authorizationUrl` does not already carry one. */
scopes?: string;
/** Redirect URI. Must be an http(s) URL without surrounding whitespace; blank values are treated as unset. */
redirectUri?: string;
/** Preferred port for the local callback listener. Required when `redirectUri` is an HTTPS loopback URL. */
callbackPort?: number;
/** Callback path; a leading `/` is added if missing. Falls back to the path of `redirectUri`, then `/callback`. */
callbackPath?: string;
}
/**
 * Generic OAuth flow for MCP servers.
 * Supports standard OAuth 2.0 authorization code flow with PKCE.
 *
 * When no client id is configured and none is embedded in the authorization
 * URL, the flow attempts dynamic client registration before building the
 * authorization request.
 */
export class MCPOAuthFlow extends OAuthCallbackFlow {
	/** Client id taken from config, the authorization URL query, or dynamic registration. */
	#resolvedClientId?: string;
	/** Client secret returned by dynamic client registration, if any. */
	#registeredClientSecret?: string;
	/** PKCE verifier for the pending authorization request; cleared when the code is exchanged. */
	#codeVerifier?: string;

	/**
	 * @param config - Endpoints, credentials and callback settings for the target server.
	 * @param ctrl - Controller forwarded to the underlying callback flow.
	 * @throws Error or TypeError when `config.redirectUri` or its port configuration is invalid.
	 */
	constructor(
		private config: MCPOAuthConfig,
		ctrl: OAuthController,
	) {
		super(ctrl, resolveCallbackOptions(config));
		this.#resolvedClientId = this.#resolveClientId(config);
	}

	/**
	 * Client id used during the authorization request. Returns the value supplied
	 * via {@link MCPOAuthConfig.clientId} (or embedded in the authorization URL)
	 * or, when the server required dynamic client registration, the id issued
	 * during registration. `undefined` until {@link generateAuthUrl} has run for
	 * a server that needs a client id.
	 */
	get resolvedClientId(): string | undefined {
		return this.#resolvedClientId;
	}

	/**
	 * Client secret issued by dynamic client registration, if any. Always
	 * `undefined` for PKCE-only/public clients and when the caller supplies the
	 * client id via config.
	 */
	get registeredClientSecret(): string | undefined {
		return this.#registeredClientSecret;
	}

	/**
	 * Builds the authorization URL for this flow.
	 *
	 * Attempts dynamic client registration when no client id is known, fills in
	 * `response_type`, `client_id` and `scope` only when the configured URL does
	 * not already carry them, always sets `redirect_uri`, `state` and a fresh
	 * S256 PKCE challenge, and stores the matching verifier for
	 * {@link exchangeToken}.
	 *
	 * @param state - Opaque state value to round-trip through the provider.
	 * @param redirectUri - Redirect URI to send to the provider.
	 * @returns The authorization URL to open.
	 * @throws Error when no client id is available and the provider's response
	 *   indicates that one is required.
	 */
	async generateAuthUrl(state: string, redirectUri: string): Promise<{ url: string; instructions?: string }> {
		if (!this.#resolvedClientId) {
			await this.#tryRegisterClient(redirectUri);
		}

		const authUrl = new URL(this.config.authorizationUrl);
		const params = authUrl.searchParams;

		if (!params.get("response_type")) {
			params.set("response_type", "code");
		}
		const existingClientId = params.get("client_id")?.trim();
		if (this.#resolvedClientId && !existingClientId) {
			params.set("client_id", this.#resolvedClientId);
		}
		if (this.config.scopes && !params.get("scope")) {
			params.set("scope", this.config.scopes);
		}
		params.set("redirect_uri", redirectUri);
		params.set("state", state);

		const codeVerifier = this.#generateCodeVerifier();
		const codeChallenge = await this.#generateCodeChallenge(codeVerifier);
		params.set("code_challenge", codeChallenge);
		params.set("code_challenge_method", "S256");
		this.#codeVerifier = codeVerifier;

		// Without a client id, probe the provider so a hard requirement surfaces
		// as a clear error instead of a failed browser round-trip.
		if (!params.get("client_id")) {
			await this.#assertClientIdNotRequired(authUrl.toString());
		}

		return { url: authUrl.toString() };
	}

	/**
	 * Exchanges an authorization code for credentials at the token endpoint.
	 *
	 * @param code - Authorization code received on the callback.
	 * @param _state - Unused; present to satisfy the flow interface.
	 * @param redirectUri - Redirect URI that was sent in the authorization request.
	 * @returns Access token, refresh token (empty string when none was issued)
	 *   and expiry timestamp in milliseconds (defaults to one hour when the
	 *   server omits `expires_in`).
	 * @throws Error when the token endpoint responds with a non-2xx status or
	 *   with a body that lacks a usable `access_token`.
	 */
	async exchangeToken(code: string, _state: string, redirectUri: string): Promise<OAuthCredentials> {
		const params = new URLSearchParams({
			grant_type: "authorization_code",
			code,
			redirect_uri: redirectUri,
		});

		if (this.#resolvedClientId) {
			params.set("client_id", this.#resolvedClientId);
		}

		if (this.#codeVerifier) {
			params.set("code_verifier", this.#codeVerifier);
		}
		// The verifier is single-use: drop it whether or not the exchange succeeds.
		this.#codeVerifier = undefined;

		const clientSecret = this.config.clientSecret ?? this.#registeredClientSecret;
		if (clientSecret) {
			params.set("client_secret", clientSecret);
		}

		const response = await fetch(this.config.tokenUrl, {
			method: "POST",
			headers: {
				"Content-Type": "application/x-www-form-urlencoded",
			},
			body: params.toString(),
		});

		if (!response.ok) {
			const errorText = await response.text();
			throw new Error(`Token exchange failed: ${response.status} ${errorText}`);
		}

		const data = (await response.json()) as {
			access_token?: string;
			refresh_token?: string;
			expires_in?: number;
			token_type?: string;
			error?: string;
			error_description?: string;
		} | null;

		// Some providers answer 200 with an error payload; never hand back
		// credentials that have no access token.
		const accessToken = data?.access_token;
		if (typeof accessToken !== "string" || accessToken.trim() === "") {
			const reason = data?.error_description ?? data?.error ?? "response did not include an access_token";
			throw new Error(`Token exchange failed: ${reason}`);
		}

		const expiresIn = data?.expires_in ?? 3600;
		const expires = Date.now() + expiresIn * 1000;

		return {
			access: accessToken,
			refresh: data?.refresh_token ?? "",
			expires,
		};
	}

	/**
	 * Generate PKCE code verifier (32 random bytes, base64url-encoded).
	 */
	#generateCodeVerifier(): string {
		const bytes = new Uint8Array(32);
		crypto.getRandomValues(bytes);
		return this.#base64UrlEncode(bytes);
	}

	/**
	 * Generate PKCE code challenge from verifier (base64url of its SHA-256 digest).
	 */
	async #generateCodeChallenge(verifier: string): Promise<string> {
		const encoder = new TextEncoder();
		const data = encoder.encode(verifier);
		const hash = await crypto.subtle.digest("SHA-256", data);
		return this.#base64UrlEncode(new Uint8Array(hash));
	}

	/**
	 * Base64 URL encode (without padding).
	 */
	#base64UrlEncode(bytes: Uint8Array): string {
		const base64 = btoa(String.fromCharCode(...bytes));
		return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=/g, "");
	}

	/**
	 * Determines the initial client id: the trimmed `config.clientId` when
	 * non-empty, otherwise a `client_id` query parameter on the authorization
	 * URL. Returns `undefined` when neither exists or the URL cannot be parsed.
	 */
	#resolveClientId(config: MCPOAuthConfig): string | undefined {
		const fromConfig = config.clientId?.trim();
		if (fromConfig) return fromConfig;

		try {
			return new URL(config.authorizationUrl).searchParams.get("client_id") ?? undefined;
		} catch {
			// Unparseable authorization URL: no client id can be derived from it.
			return undefined;
		}
	}

	/**
	 * Try OAuth dynamic client registration when provider requires a client_id.
	 *
	 * Best effort: a missing registration endpoint, a non-2xx response or a
	 * request failure leaves the flow without a registered client.
	 */
	async #tryRegisterClient(redirectUri: string): Promise<void> {
		const registrationEndpoint = await this.#resolveRegistrationEndpoint();
		if (!registrationEndpoint) return;

		try {
			const response = await fetch(registrationEndpoint, {
				method: "POST",
				headers: {
					"Content-Type": "application/json",
					Accept: "application/json",
				},
				body: JSON.stringify({
					client_name: "Codex",
					redirect_uris: [redirectUri],
					grant_types: ["authorization_code", "refresh_token"],
					response_types: ["code"],
					token_endpoint_auth_method: "none",
					application_type: "native",
				}),
			});

			if (!response.ok) return;

			const data = (await response.json()) as {
				client_id?: string;
				client_secret?: string;
			};

			if (data.client_id && data.client_id.trim() !== "") {
				this.#resolvedClientId = data.client_id;
			}
			if (data.client_secret && data.client_secret.trim() !== "") {
				this.#registeredClientSecret = data.client_secret;
			}
		} catch {
			// Registration is optional; proceed without a registered client.
		}
	}

	/**
	 * Looks for a `registration_endpoint` in authorization-server metadata.
	 *
	 * Tries, in order: the well-known document at the origin root, the one
	 * under the authorization URL's parent path, and the path-suffixed form
	 * `/.well-known/oauth-authorization-server<path>`.
	 *
	 * @returns The first registration endpoint found, or `null`.
	 */
	async #resolveRegistrationEndpoint(): Promise<string | null> {
		const authorizationUrl = new URL(this.config.authorizationUrl);

		const rootUrl = new URL("/.well-known/oauth-authorization-server", authorizationUrl.origin).toString();
		const endpoint = await this.#tryWellKnownForRegistration(rootUrl);
		if (endpoint) return endpoint;

		const normalizedPath = authorizationUrl.pathname.replace(/\/$/, "");
		const lastSlash = normalizedPath.lastIndexOf("/");
		// A bare "/" pathname normalizes to "", leaving no path to derive further URLs from.
		if (lastSlash < 0) return null;

		const prefixPath = lastSlash === 0 ? normalizedPath : normalizedPath.slice(0, lastSlash);
		const prefixedUrl = new URL(
			".well-known/oauth-authorization-server",
			`${authorizationUrl.origin}${prefixPath}/`,
		).toString();
		const prefixedEndpoint = await this.#tryWellKnownForRegistration(prefixedUrl);
		if (prefixedEndpoint) return prefixedEndpoint;

		const pathfulUrl = new URL(
			`/.well-known/oauth-authorization-server${normalizedPath}`,
			authorizationUrl.origin,
		).toString();
		return await this.#tryWellKnownForRegistration(pathfulUrl);
	}

	/**
	 * Fetches one well-known metadata document and returns its non-blank
	 * `registration_endpoint`, or `null` when the request fails, the response is
	 * not OK, or the field is absent.
	 */
	async #tryWellKnownForRegistration(wellKnownUrl: string): Promise<string | null> {
		try {
			const response = await fetch(wellKnownUrl, {
				method: "GET",
				headers: { Accept: "application/json" },
			});
			if (!response.ok) return null;

			const metadata = (await response.json()) as { registration_endpoint?: string };
			if (metadata.registration_endpoint && metadata.registration_endpoint.trim() !== "") {
				return metadata.registration_endpoint;
			}
		} catch {
			// Discovery is best effort; treat any failure as "no metadata here".
		}
		return null;
	}

	/**
	 * Probes the authorization URL (without following redirects) to detect
	 * providers that reject requests lacking a client id.
	 *
	 * @throws Error when the provider answers with status >= 400 and a body that
	 *   mentions the client id together with "required", "missing" or "invalid".
	 *   All other failures, including network errors, are ignored.
	 */
	async #assertClientIdNotRequired(authorizationUrl: string): Promise<void> {
		try {
			const response = await fetch(authorizationUrl, {
				method: "GET",
				redirect: "manual",
				headers: { Accept: "text/plain,text/html,application/json" },
			});

			if (response.status < 400) return;

			const body = await response.text();
			if (/client[_-]?id/i.test(body) && /(required|missing|invalid)/i.test(body)) {
				throw new Error("OAuth provider requires client_id");
			}
		} catch (error) {
			// Re-throw only our own client-id error; probe failures are non-fatal.
			if (error instanceof Error && /client[_-]?id/i.test(error.message)) {
				throw error;
			}
		}
	}
}
/**
 * Refresh an MCP OAuth token using the standard refresh_token grant.
 * Returns updated credentials; preserves the old refresh token if the server doesn't rotate it.
 *
 * @param tokenUrl - Token endpoint URL.
 * @param refreshToken - Refresh token to redeem.
 * @param clientId - Optional client id, sent as `client_id` when non-empty.
 * @param clientSecret - Optional client secret, sent as `client_secret` when non-empty.
 * @returns New access token, refresh token and expiry timestamp in
 *   milliseconds (defaults to one hour when the server omits `expires_in`).
 * @throws Error when the token endpoint responds with a non-2xx status or
 *   with a body that lacks a usable `access_token`.
 */
export async function refreshMCPOAuthToken(
	tokenUrl: string,
	refreshToken: string,
	clientId?: string,
	clientSecret?: string,
): Promise<OAuthCredentials> {
	const params = new URLSearchParams({
		grant_type: "refresh_token",
		refresh_token: refreshToken,
	});
	if (clientId) params.set("client_id", clientId);
	if (clientSecret) params.set("client_secret", clientSecret);

	const response = await fetch(tokenUrl, {
		method: "POST",
		headers: { "Content-Type": "application/x-www-form-urlencoded" },
		body: params.toString(),
	});

	if (!response.ok) {
		const text = await response.text();
		throw new Error(`MCP OAuth refresh failed: ${response.status} ${text}`);
	}

	const data = (await response.json()) as {
		access_token?: string;
		refresh_token?: string;
		expires_in?: number;
		error?: string;
		error_description?: string;
	} | null;

	// Some providers answer 200 with an error payload; never hand back
	// credentials that have no access token.
	const accessToken = data?.access_token;
	if (typeof accessToken !== "string" || accessToken.trim() === "") {
		const reason = data?.error_description ?? data?.error ?? "response did not include an access_token";
		throw new Error(`MCP OAuth refresh failed: ${reason}`);
	}

	const expiresIn = data?.expires_in ?? 3600;
	return {
		access: accessToken,
		refresh: data?.refresh_token ?? refreshToken,
		expires: Date.now() + expiresIn * 1000,
	};
}