/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* This source file is part of the Cangjie project, licensed under Apache-2.0
* with Runtime Library Exception.
*
* See https://cangjie-lang.cn/pages/LICENSE for license information.
*/
// The Cangjie API is in Beta. For details on its capabilities and limitations, please refer to the README file.
package std.unittest.mock
import std.collection.*
import std.sync.*
import std.unittest.mock.internal.*
// Name of the session that MockFramework opens on construction and never closes.
private const GLOBAL_SESSION_NAME = "Global"
// Single framework instance; every static entry point of MockFramework delegates to it.
let _FRAMEWORK = MockFramework()
// Stub chains paired with their match status. The list is handed to each session while an
// invocation is being handled and later to the unhandled-invocation report.
// NOTE(review): presumably only chains that failed to match are added — confirm in MockSession.handle.
type UnmatchedChains = ArrayList<(StubChain, MatchStatus)>
/**
 * Central coordinator of the mock framework.
 *
 * Keeps the stack of open sessions, the currently active recording (if any), and a cache of
 * resolved stack trace frames. A single instance is stored in `_FRAMEWORK`; the static functions
 * of this class delegate to it. State-changing operations are funneled through `lock`.
 */
public class MockFramework {
    // Stack of open sessions; the global session is always at index 0 and the innermost one is last.
    private let sessions: ArrayList<MockSession> = ArrayList.of(MockSession(GLOBAL_SESSION_NAME, Stateless))
    // Set only while `recordCall` is executing its dummy call.
    private var recording: Option<RecordingHandler> = None
    // Reentrant so that framework functions holding the lock may call other locking functions.
    private let globalLock = ReentrantMutex()
    // Source of identifiers handed out by `generateUniqueId`.
    private let objCounter = AtomicUInt64(0)
    // Resolved stack trace frames keyed by their raw form; cleared when a session is closed successfully.
    private let stackTraceCache = HashMap<RawStackTraceElement, StackTraceElement>()
    protected static const testCaseSessionPrefix = "test case "
    protected static const benchSessionPrefix = "benchmark batch of "
    protected static const initSessionPrefix = "initializer of class "
    protected static const beforeEachSession = "func beforeEach"

    /**
     * Registers the static call handler associated with the global session.
     */
    init() {
        CallHandler.recordStatic(CallHandlerImpl(objCounter.fetchAdd(1), GLOBAL_SESSION_NAME))
    }

    /**
     * Opens a new session. Sessions form a stacklike structure.
     * Sessions are closed in the reverse order they were opened in.
     * Mock objects created during a given session are only accessible inside the session or any of its inner sessions.
     * Each session keeps its own invocation log so any verification is performed on calls made inside latest open session.
     * Expectations can only be verified whenever the session is closed.
     *
     * @param name to identify a session for debugging purposes.
     * @param sessionKind indicating what stubs are allowed in this session.
     */
    public static func openSession(name: String, sessionKind: MockSessionKind): Unit {
        return _FRAMEWORK.doOpen(name, sessionKind)
    }

    /**
     * Closes the most recently opened session. Sessions can only be closed in the reverse order they were opened in.
     *
     * Checks for misconfiguration errors such as stubs created but not configured.
     * Automatically verifies expectations for all declared stubs iff the session was *Verifiable*. See *MockSessionKind*.
     *
     * @throws MockFrameworkException if any misconfiguration errors were detected.
     * @throws ExpectationFailedException if any of the expectations were not met.
     */
    public static func closeSession(): Unit {
        return _FRAMEWORK.doClose()
    }

    /**
     * Tells whether the innermost session is named like the benchmark session of `caseName`,
     * i.e. `benchSessionPrefix + caseName`.
     */
    protected static func isCurrentSessionForBench(caseName: String): Bool {
        _FRAMEWORK.currentSession().name == benchSessionPrefix + caseName
    }

    /**
     * Returns the current counter value and advances the counter by one.
     */
    func generateUniqueId(): UInt64 {
        return objCounter.fetchAdd(1)
    }

    /**
     * Runs `body` while holding the global lock and returns its result.
     * The lock is released even when `body` throws.
     */
    private func lock<T>(body: () -> T): T {
        // Acquire outside of `try`: if acquiring fails, `finally` must not unlock a mutex we do not hold.
        globalLock.lock()
        try {
            body()
        } finally {
            globalLock.unlock()
        }
    }

    /**
     * Pushes a new session and registers a static call handler for it.
     * A session may only be opened while the current session is Stateless.
     */
    private func doOpen(name: String, sessionKind: MockSessionKind): Unit {
        lock {
            match (currentSession().sessionKind) {
                case Stateless =>
                    sessions.add(MockSession(name, sessionKind))
                    CallHandler.recordStatic(CallHandlerImpl(0, name))
                case _ => illegalInput("Invalid session structure")
            }
        }
    }

    /**
     * Runs `body` with the innermost session while holding the global lock.
     */
    static func session<T>(body: (MockSession) -> T): T {
        return _FRAMEWORK.lock { body(_FRAMEWORK.currentSession()) }
    }

    /**
     * Pops the innermost session and validates it.
     * The session is removed before validation, so it stays closed even when validation throws.
     */
    private func doClose(): Unit {
        lock {
            let lastSessionIndex = match (sessions.size) {
                case 0 => internalError("Should always have global session open")
                case 1 => illegalInput("Cannot close session: no sessions were opened")
                case size => size - 1
            }
            let sessionToClose = sessions[lastSessionIndex]
            sessions.remove(at: lastSessionIndex)
            CallHandler.clearStatic()
            sessionToClose.checkForMisconfigurationIssues()
            sessionToClose.invocationLog.validateLog()
            if (let Verifiable <- sessionToClose.sessionKind) {
                sessionToClose.verifyExpectations()
            }
            stackTraceCache.clear()
        }
    }

    /**
     * Returns the innermost open session.
     */
    private func currentSession(): MockSession {
        return sessions.get(sessions.size - 1) ?? internalError("Should always have global session open")
    }

    /**
     * Executes `dummyCall` in recording mode and returns what the recording handler captured.
     *
     * While the call runs, `handleCall` forwards invocations to the recording handler instead of
     * performing them. Recording mode is always switched off afterwards, including when
     * `dummyCall` throws, so a failed recording cannot affect later calls.
     */
    static func recordCall<R>(
        recordingInfo: RecordingInfo, dummyCall: () -> R
    ): RecordedCall {
        _FRAMEWORK.lock {
            let handler = RecordingHandler(recordingInfo)
            _FRAMEWORK.recording = Some(handler)
            try {
                dummyCall()
            } finally {
                // Without this reset an exception from the dummy call would leave the framework recording forever.
                _FRAMEWORK.recording = None
            }
            handler.getRecorded()
        }
    }

    /**
     * Translates the action of a matched scenario into the instruction returned to the call site.
     * Field actions read from / write to the field storage of `currentSession`.
     */
    private func getAction(
        sc: Scenario, currentSession: MockSession, invocation: Invocation, call: Call
    ): OnCall {
        let invocationLog = currentSession.invocationLog
        return match (sc.action) {
            case Return(valueFactory) => OnCall.Return(valueFactory())
            case Throw(exceptionFactory) => OnCall.Throw(exceptionFactory())
            case CallOriginal => OnCall.CallBase
            case GetField(description) =>
                OnCall.Return(currentSession.fieldStorage.getExplicitFieldValue(description))
            case SetField(description) =>
                currentSession.fieldStorage.setExplicitFieldValue(description, getSingleArg(call))
                OnCall.Return(())
            case Fail =>
                logAndThrow(invocationLog, invocation,
                    runtimeUnwantedInteraction(invocation, sc, invocationLog.fullLog())
                )
        }
    }

    /**
     * Decides what an intercepted call should do. Must be called with the global lock held.
     *
     * An `objId` of 0 is treated as a call without an associated mock object; any other id must
     * belong to an object registered in one of the open sessions, otherwise a scope violation is reported.
     * While a recording is active the call is only recorded. Otherwise sessions are asked to handle
     * the invocation from the innermost to the outermost one; if none does, default behaviour applies.
     * Log entries are always written to the innermost session's invocation log.
     */
    private func handleCall(call: Call, objId: UInt64, objectCreatedAtSessionName: ?String): OnCall {
        let currentSession = currentSession()
        let obj = if (objId != 0) {
            findObjectById(objId) ??
                MockScopeViolation().reportInaccessible(objectCreatedAtSessionName, currentSession.name)
        } else {
            Option<MockObject>.None
        }
        if (let Some(recording) <- recording) {
            recording.record(call, obj)
            return OnCall.ReturnZero
        }
        let invocation = Invocation(obj, call, LocationImpl.fromTrace())
        let unmatchedChains = UnmatchedChains()
        let invocationLog = currentSession.invocationLog
        for (i in (sessions.size - 1)..=0 : -1) {
            let handlingSession = sessions[i]
            match (handlingSession.handle(invocation, unmatchedChains)) {
                case Unhandled => ()
                case PerformScenario(sc) =>
                    invocationLog.newEntry(LogEntry(invocation, Stub(sc.parentChain)))
                    return getAction(sc, currentSession, invocation, call)
                case InvocationLimitExceeded(chain) =>
                    logAndThrow(invocationLog, invocation,
                        invocationLimitExceededReport(invocation, chain, invocationLog.fullLog())
                    )
            }
        }
        let defaultBehaviour = handleUnstubbed(invocation, invocationLog, unmatchedChains)
        invocationLog.newEntry(LogEntry(invocation, Default))
        return defaultBehaviour
    }

    /**
     * Chooses the behaviour for an invocation that no stub handled.
     *
     * Calls without a mock object and calls on spies go to the original implementation.
     * For objects in SyntheticFields mode setters store and getters read an automatic field value.
     * Otherwise a default value is returned when allowed; if not, the failure is logged and thrown.
     */
    private func handleUnstubbed(invocation: Invocation, log: MutableLog, unmatchedChains: UnmatchedChains): OnCall {
        let mockObject = match (invocation.mockObject) {
            case Some(o) => o
            case _ => return OnCall.CallBase
        }
        if (let Spy <- mockObject.kind) {
            return OnCall.CallBase
        }
        let funcInfo = invocation.call.funcInfo
        if (mockObject.hasMode(SyntheticFields)) {
            match {
                case funcInfo.isSetter =>
                    currentSession().fieldStorage.setAutoFieldValue(
                        AutoFieldId(mockObject, funcInfo.presentableName), getSingleArg(invocation.call)
                    )
                    return OnCall.Return(())
                case funcInfo.isGetter =>
                    if (let Some(value) <- currentSession().fieldStorage.getAutoFieldValue(
                        AutoFieldId(mockObject, funcInfo.presentableName)
                    )) {
                        return OnCall.Return(value)
                    }
                    if (invocation.canReturnDefaultValue) {
                        return OnCall.ReturnDefault
                    }
                    logAndThrow(log, invocation,
                        FieldsErrorReport().errorReadingField(invocation)
                    )
                case _ => ()
            }
        }
        if (invocation.canReturnDefaultValue) {
            return OnCall.ReturnDefault
        }
        logAndThrow(log, invocation,
            unhandledInvocationReport(invocation, unmatchedChains, log.fullLog())
        )
    }

    /**
     * Appends a failure entry for `invocation` to `log` and then throws `error`.
     */
    private func logAndThrow(log: MutableLog, invocation: Invocation, error: PrettyException): Nothing {
        log.newEntry(LogEntry(invocation, Failure(error)))
        throw error
    }

    /**
     * Locked entry point for intercepted calls; see `handleCall`.
     */
    func handleSync(call: Call, objId: UInt64, objectCreatedAtSessionName: ?String): OnCall {
        lock {
            handleCall(call, objId, objectCreatedAtSessionName)
        }
    }

    /**
     * Registers `mockObject` under its id in the innermost session.
     */
    func registerObject(mockObject: MockObject) {
        lock {
            currentSession().objectRegistry[mockObject.id] = mockObject
        }
    }

    /**
     * Looks `id` up in the object registries of all open sessions, outermost first.
     */
    private func findObjectById(id: UInt64): Option<MockObject> {
        for (session in sessions) {
            if (let Some(obj) <- session.objectRegistry.get(id)) {
                return obj
            }
        }
        return None
    }

    /**
     * Returns the ids of all registered mock objects whose `ref` is reference-equal
     * to one of `referencesToFind`.
     */
    func findObjectIdsByRefs(referencesToFind: Array<Object>): Set<UInt64> {
        lock {
            doFindObjectIdsByRefs(referencesToFind)
        }
    }

    // Unlocked implementation of `findObjectIdsByRefs`; scans the registries of all open sessions.
    private func doFindObjectIdsByRefs(referencesToFind: Array<Object>) {
        let ids = HashSet<UInt64>()
        for (session in sessions) {
            for ((id, mockObject) in session.objectRegistry) {
                if (referencesToFind |> any { targetRef: Object => refEq(mockObject.ref, targetRef) }) {
                    ids.add(id)
                }
            }
        }
        return ids
    }

    /**
     * Returns the only argument of `call`; reports an internal error for any other argument count.
     */
    private func getSingleArg(call: Call): Any {
        if (call.args.size != 1) {
            internalError(InconsistentState("Setters must have a single argument."))
        }
        return call.args[0]
    }

    /**
     * Fails the active recording because the recorded callable is not a mock.
     * Reports an internal error when no recording is active.
     */
    protected func throwNotAMockObject() {
        let recordingInfo = recording?.info ??
            internalError(InternalError.InconsistentState("There is no recording call info before actual calling"))
        RecordingFailureReport().fail(recordingInfo, NotAMock, UnmockableCallable)
    }

    /**
     * Returns the cached element of the first frame of `rawStackTrace` that is present in the cache.
     * The raw trace is read as consecutive triples, one triple per frame.
     */
    func getCachedStackTraceFrame(rawStackTrace: Array<UInt64>): Option<StackTraceElement> {
        let frameSize = rawStackTrace.size / 3
        for (i in 0..frameSize) {
            // Single lookup instead of `contains` followed by indexing.
            if (let Some(cached) <- stackTraceCache.get(RawStackTraceElement.byIndex(rawStackTrace, i))) {
                return cached
            }
        }
        return Option<StackTraceElement>.None
    }

    /**
     * Caches `stackTraceElement` for the frame at `frameIndex` of `rawStackTrace`.
     */
    func recordStackTraceFrame(
        rawStackTrace: Array<UInt64>, frameIndex: Int64, stackTraceElement: StackTraceElement
    ) {
        stackTraceCache[RawStackTraceElement.byIndex(rawStackTrace, frameIndex)] = stackTraceElement
    }
}
extend MockObject {
    /**
     * Tells whether `mode` is among this object's `modes`.
     * Only SyntheticFields and ReturnsDefaults are recognized; any other pairing never matches.
     */
    func hasMode(mode: StubMode) {
        for (candidate in modes) {
            let sameMode = match ((candidate, mode)) {
                case (SyntheticFields, SyntheticFields) => true
                case (ReturnsDefaults, ReturnsDefaults) => true
                case _ => false
            }
            if (sameMode) {
                return true
            }
        }
        return false
    }
}
extend Invocation {
    // True when the invocation targets a mock object in ReturnsDefaults mode and the call is not a setter.
    prop isDefaultEnabled: Bool {
        get() {
            match (mockObject) {
                case Some(target) => target.hasMode(ReturnsDefaults) && !call.funcInfo.isSetter
                case None => false
            }
        }
    }
    // True when defaults are enabled for this invocation and the called function has a default value.
    prop canReturnDefaultValue: Bool {
        get() {
            if (isDefaultEnabled) {
                call.funcInfo.hasDefaultValue
            } else {
                false
            }
        }
    }
}
/**
 * Call handler that forwards every intercepted call to the shared framework instance,
 * together with the object id and session name it was created with.
 */
class CallHandlerImpl <: CallHandler {
    private let objId: UInt64
    private let objectCreatedAtSessionName: ?String

    init(objId: UInt64, objectCreatedAtSessionName: ?String) {
        this.objId = objId
        this.objectCreatedAtSessionName = objectCreatedAtSessionName
    }

    // Delegates to the framework, which handles the call under its global lock.
    public func onCall(call: Call): OnCall {
        return _FRAMEWORK.handleSync(call, objId, objectCreatedAtSessionName)
    }

    public func throwNotAMockObject(): Nothing {
        _FRAMEWORK.throwNotAMockObject()
    }
}
/**
 * Builds and throws the error raised when a mock object is used in a session
 * where it is not registered.
 */
class MockScopeViolation <: FailureReport {
    static let HINT = "Hint: accessing mock objects outside its intended use scope can lead to unstable tests"

    /**
     * Always throws a MockFrameworkException describing the illegal access.
     * Either session name is left out of the message when it is None.
     */
    func reportInaccessible(createdAtSessionName: ?String, whereSessionName: ?String) {
        let report = build {
            errorHeader("Illegal mock object access")
            line {
                text("Mock object")
                if (let Some(origin) <- createdAtSessionName) {
                    text("created in")
                    userCode(origin)
                }
                text("is no longer accessible")
                if (let Some(place) <- whereSessionName) {
                    text("in")
                    userCode(place)
                }
            }
            line {
                text(HINT)
            }
        }
        throw MockFrameworkException(report)
    }
}
/**
 * One frame of a raw stack trace: three consecutive UInt64 values read from the raw array,
 * stored as method name, file name and line number. Usable as a hash map key.
 */
private struct RawStackTraceElement <: Hashable & Equatable<RawStackTraceElement> {
    let methodName: UInt64
    let fileName: UInt64
    let lineNumber: UInt64

    // Reads the frame at `frameIndex`, i.e. the three values starting at offset `frameIndex * 3`.
    static func byIndex(rawStackTrace: Array<UInt64>, frameIndex: Int64) {
        let offset = frameIndex * 3
        return RawStackTraceElement(rawStackTrace[offset], rawStackTrace[offset + 1], rawStackTrace[offset + 2])
    }

    private init(methodName: UInt64, fileName: UInt64, lineNumber: UInt64) {
        this.methodName = methodName
        this.fileName = fileName
        this.lineNumber = lineNumber
    }

    // Feeds the three components to the hasher in declaration order.
    public func hashCode(): Int64 {
        var state = DefaultHasher()
        state.write(methodName)
        state.write(fileName)
        state.write(lineNumber)
        return state.finish()
    }

    // Two elements are equal when all three components are equal.
    public operator func ==(that: RawStackTraceElement): Bool {
        if (methodName != that.methodName) {
            return false
        }
        if (fileName != that.fileName) {
            return false
        }
        return lineNumber == that.lineNumber
    }

    public operator func !=(that: RawStackTraceElement): Bool {
        return !(this == that)
    }
}